test-driver: Factor out OCR related code to machine/ocr.py
This commit is contained in:
@@ -13,8 +13,8 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Generator
|
||||||
from contextlib import _GeneratorContextManager, nullcontext
|
from contextlib import _GeneratorContextManager, contextmanager, nullcontext
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -22,6 +22,7 @@ from typing import Any
|
|||||||
from test_driver.errors import MachineError, RequestedAssertionFailed
|
from test_driver.errors import MachineError, RequestedAssertionFailed
|
||||||
from test_driver.logger import AbstractLogger
|
from test_driver.logger import AbstractLogger
|
||||||
|
|
||||||
|
from .ocr import perform_ocr_on_screenshot, perform_ocr_variants_on_screenshot
|
||||||
from .qmp import QMPSession
|
from .qmp import QMPSession
|
||||||
|
|
||||||
CHAR_TO_KEY = {
|
CHAR_TO_KEY = {
|
||||||
@@ -92,84 +93,6 @@ def make_command(args: list) -> str:
|
|||||||
return " ".join(map(shlex.quote, (map(str, args))))
|
return " ".join(map(shlex.quote, (map(str, args))))
|
||||||
|
|
||||||
|
|
||||||
def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str:
    """
    Produce a cleaned-up PNG copy of a screenshot using ImageMagick.

    The copy is resampled, contrast-stretched, despeckled, converted to
    grayscale, sharpened and posterized; with `negate` set the colours are
    additionally inverted.

    The result is written next to the input as `<screenshot_path>.png`, or
    `<screenshot_path>.negative.png` when `negate` is set, and that path is
    returned.

    Raises MachineError when `magick convert` exits with a non-zero status.
    """
    # Same file name as the input with an extra suffix, so the copy lands in
    # the same directory as the original.
    out_file = screenshot_path + (".negative.png" if negate else ".png")

    command = [
        "magick",
        "convert",
        "-filter",
        "Catrom",
        "-density",
        "72",
        "-resample",
        "300",
        "-contrast",
        "-normalize",
        "-despeckle",
        "-type",
        "grayscale",
        "-sharpen",
        "1",
        "-posterize",
        "3",
        # The inversion has to sit between the posterize and gamma steps.
        *(["-negate"] if negate else []),
        "-gamma",
        "100",
        "-blur",
        "1x65535",
        screenshot_path,
        out_file,
    ]
    ret = subprocess.run(command, capture_output=True)

    if ret.returncode == 0:
        return out_file

    raise MachineError(
        f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}"
    )
|
|
||||||
|
|
||||||
|
|
||||||
def _perform_ocr_on_screenshot(
    screenshot_path: str, model_ids: Iterable[int]
) -> list[str]:
    """
    Run tesseract over a screenshot and two preprocessed copies of it.

    Three images are processed in order: the original screenshot, a
    preprocessed copy and a preprocessed, colour-inverted copy. Each image
    is run through tesseract once per entry in `model_ids` (passed to
    tesseract as `--oem`).

    Returns the decoded tesseract output of every run, ordered by image
    first and by model id second.

    Raises MachineError when `tesseract` is not on PATH, when preprocessing
    fails, or when a tesseract run exits with a non-zero status.
    """
    if shutil.which("tesseract") is None:
        raise MachineError("OCR requested but enableOCR is false")

    # `model_ids` is only promised to be an Iterable and is walked once per
    # image below. Snapshot it so a one-shot iterator (e.g. a generator) is
    # not silently exhausted after the first image.
    engine_modes = list(model_ids)

    processed_image = _preprocess_screenshot(screenshot_path, negate=False)
    processed_negative = _preprocess_screenshot(screenshot_path, negate=True)

    model_results = []
    for image in [screenshot_path, processed_image, processed_negative]:
        for model_id in engine_modes:
            ret = subprocess.run(
                [
                    "tesseract",
                    image,
                    # "-" makes tesseract write the recognised text to stdout.
                    "-",
                    "--oem",
                    str(model_id),
                    "-c",
                    "debug_file=/dev/null",
                    "--psm",
                    "11",
                ],
                capture_output=True,
            )
            if ret.returncode != 0:
                raise MachineError(f"OCR failed with exit code {ret.returncode}")
            model_results.append(ret.stdout.decode("utf-8"))

    return model_results
|
|
||||||
|
|
||||||
|
|
||||||
def retry(fn: Callable, timeout: int = 900) -> None:
|
def retry(fn: Callable, timeout: int = 900) -> None:
|
||||||
"""Call the given function repeatedly, with 1 second intervals,
|
"""Call the given function repeatedly, with 1 second intervals,
|
||||||
until it returns True or a timeout is reached.
|
until it returns True or a timeout is reached.
|
||||||
@@ -910,6 +833,17 @@ class Machine:
|
|||||||
self.log(f"(connecting took {toc - tic:.2f} seconds)")
|
self.log(f"(connecting took {toc - tic:.2f} seconds)")
|
||||||
self.connected = True
|
self.connected = True
|
||||||
|
|
||||||
|
@contextmanager
def _managed_screenshot(self) -> Generator[str]:
    """
    Dump the current display into a throwaway file and yield its path.

    The file is created inside a fresh temporary directory via the
    `screendump` monitor command. The directory, and the dump with it, is
    removed as soon as the surrounding `with` block is left.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        dump_path: str = os.path.join(scratch_dir, "ppm")
        self.send_monitor_command(f"screendump {dump_path}")
        yield dump_path
|
||||||
|
|
||||||
def screenshot(self, filename: str) -> None:
|
def screenshot(self, filename: str) -> None:
|
||||||
"""
|
"""
|
||||||
Take a picture of the display of the virtual machine, in PNG format.
|
Take a picture of the display of the virtual machine, in PNG format.
|
||||||
@@ -919,17 +853,19 @@ class Machine:
|
|||||||
filename += ".png"
|
filename += ".png"
|
||||||
if "/" not in filename:
|
if "/" not in filename:
|
||||||
filename = os.path.join(self.out_dir, filename)
|
filename = os.path.join(self.out_dir, filename)
|
||||||
tmp = f"{filename}.ppm"
|
|
||||||
|
|
||||||
with self.nested(
|
with self.nested(
|
||||||
f"making screenshot {filename}",
|
f"making screenshot {filename}",
|
||||||
{"image": os.path.basename(filename)},
|
{"image": os.path.basename(filename)},
|
||||||
):
|
):
|
||||||
self.send_monitor_command(f"screendump {tmp}")
|
with self._managed_screenshot() as screenshot_path:
|
||||||
ret = subprocess.run(f"pnmtopng '{tmp}' > '{filename}'", shell=True)
|
ret = subprocess.run(
|
||||||
os.unlink(tmp)
|
f"pnmtopng '{screenshot_path}' > '{filename}'", shell=True
|
||||||
if ret.returncode != 0:
|
)
|
||||||
raise MachineError("Cannot convert screenshot")
|
if ret.returncode != 0:
|
||||||
|
raise MachineError(
|
||||||
|
f"Cannot convert screenshot (pnmtopng returned code {ret.returncode})"
|
||||||
|
)
|
||||||
|
|
||||||
def copy_from_host_via_shell(self, source: str, target: str) -> None:
|
def copy_from_host_via_shell(self, source: str, target: str) -> None:
|
||||||
"""Copy a file from the host into the guest by piping it over the
|
"""Copy a file from the host into the guest by piping it over the
|
||||||
@@ -1003,12 +939,6 @@ class Machine:
|
|||||||
"""Debugging: Dump the contents of the TTY<n>"""
|
"""Debugging: Dump the contents of the TTY<n>"""
|
||||||
self.execute(f"fold -w 80 /dev/vcs{tty} | systemd-cat")
|
self.execute(f"fold -w 80 /dev/vcs{tty} | systemd-cat")
|
||||||
|
|
||||||
def _get_screen_text_variants(self, model_ids: Iterable[int]) -> list[str]:
    """
    Take a screendump and return the OCR results for it.

    The dump is written into a temporary directory that is removed once
    OCR has finished. `model_ids` is forwarded unchanged to
    `_perform_ocr_on_screenshot`, whose result list is returned as-is.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        dump_path = os.path.join(scratch_dir, "ppm")
        self.send_monitor_command(f"screendump {dump_path}")
        ocr_results = _perform_ocr_on_screenshot(dump_path, model_ids)
        return ocr_results
|
|
||||||
|
|
||||||
def get_screen_text_variants(self) -> list[str]:
|
def get_screen_text_variants(self) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Return a list of different interpretations of what is currently
|
Return a list of different interpretations of what is currently
|
||||||
@@ -1021,7 +951,8 @@ class Machine:
|
|||||||
This requires [`enableOCR`](#test-opt-enableOCR) to be set to `true`.
|
This requires [`enableOCR`](#test-opt-enableOCR) to be set to `true`.
|
||||||
:::
|
:::
|
||||||
"""
|
"""
|
||||||
return self._get_screen_text_variants([0, 1, 2])
|
with self._managed_screenshot() as screenshot_path:
|
||||||
|
return perform_ocr_variants_on_screenshot(screenshot_path)
|
||||||
|
|
||||||
def get_screen_text(self) -> str:
|
def get_screen_text(self) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -1032,7 +963,8 @@ class Machine:
|
|||||||
This requires [`enableOCR`](#test-opt-enableOCR) to be set to `true`.
|
This requires [`enableOCR`](#test-opt-enableOCR) to be set to `true`.
|
||||||
:::
|
:::
|
||||||
"""
|
"""
|
||||||
return self._get_screen_text_variants([2])[0]
|
with self._managed_screenshot() as screenshot_path:
|
||||||
|
return perform_ocr_on_screenshot(screenshot_path)
|
||||||
|
|
||||||
def wait_for_text(self, regex: str, timeout: int = 900) -> None:
|
def wait_for_text(self, regex: str, timeout: int = 900) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
111
nixos/lib/test-driver/src/test_driver/machine/ocr.py
Normal file
111
nixos/lib/test-driver/src/test_driver/machine/ocr.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from test_driver.errors import MachineError
|
||||||
|
|
||||||
|
|
||||||
|
def perform_ocr_on_screenshot(screenshot_path: str) -> str:
    """
    Extract the text visible in a screenshot via OCR.

    Delegates to `perform_ocr_variants_on_screenshot` with variants
    disabled and returns the first entry of the resulting list, i.e. the
    text recognised in the unmodified screenshot.
    """
    ocr_results = perform_ocr_variants_on_screenshot(
        screenshot_path, variants=False
    )
    return ocr_results[0]
|
||||||
|
|
||||||
|
|
||||||
|
def perform_ocr_variants_on_screenshot(
    screenshot_path: str, variants: bool = True
) -> list[str]:
    """
    Run tesseract on a screenshot and return the recognised text.

    With `variants` enabled (the default) the original screenshot plus two
    preprocessed copies (a normal and a colour-inverted one) are each run
    through OCR engine modes 0, 1 and 2. This can lead to more words being
    detected. The result then holds nine strings, ordered by image first
    and by engine mode second.

    With `variants` disabled only the original screenshot is processed,
    using tesseract's default engine mode (3), and the result holds exactly
    one string. No ImageMagick preprocessing happens in that case, so
    `magick` does not need to be installed.

    Raises MachineError when `tesseract` is not on PATH, when preprocessing
    fails, or when a tesseract run exits with a non-zero status.
    """
    if shutil.which("tesseract") is None:
        raise MachineError("OCR requested but `tesseract` is not available")

    # tesseract --help-oem
    # OCR Engine modes (OEM):
    #  0|tesseract_only          Legacy engine only.
    #  1|lstm_only               Neural nets LSTM engine only.
    #  2|tesseract_lstm_combined Legacy + LSTM engines.
    #  3|default                 Default, based on what is available.
    model_ids: list[int] = [0, 1, 2] if variants else [3]

    image_paths = [screenshot_path]
    if variants:
        # Preprocessing spawns ImageMagick once per copy and each extra image
        # costs one tesseract run per engine mode, so only pay for it when
        # the caller actually asked for variants.
        image_paths += [
            _preprocess_screenshot(screenshot_path, negate=False),
            _preprocess_screenshot(screenshot_path, negate=True),
        ]

    def run_tesseract(image: str, model_id: int) -> str:
        """Run tesseract once and return its decoded stdout."""
        ret = subprocess.run(
            [
                "tesseract",
                image,
                # "-" makes tesseract write the recognised text to stdout.
                "-",
                "--oem",
                str(model_id),
                "-c",
                "debug_file=/dev/null",
                "--psm",
                "11",
            ],
            capture_output=True,
        )
        if ret.returncode != 0:
            raise MachineError(f"OCR failed with exit code {ret.returncode}")
        return ret.stdout.decode("utf-8")

    return [
        run_tesseract(image, model_id)
        for image in image_paths
        for model_id in model_ids
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str:
    """
    Produce a cleaned-up PNG copy of a screenshot using ImageMagick.

    The copy is resampled, contrast-stretched, despeckled, converted to
    grayscale, sharpened and posterized; with `negate` set the colours are
    additionally inverted.

    The result is written next to the input as `<screenshot_path>.png`, or
    `<screenshot_path>.negative.png` when `negate` is set, and that path is
    returned.

    Raises MachineError when `magick` is not on PATH or when
    `magick convert` exits with a non-zero status.
    """
    if shutil.which("magick") is None:
        raise MachineError("OCR requested but `magick` is not available")

    # Steps applied before the optional colour inversion.
    leading_steps = [
        "-filter",
        "Catrom",
        "-density",
        "72",
        "-resample",
        "300",
        "-contrast",
        "-normalize",
        "-despeckle",
        "-type",
        "grayscale",
        "-sharpen",
        "1",
        "-posterize",
        "3",
    ]
    # Steps applied after the optional colour inversion.
    trailing_steps = [
        "-gamma",
        "100",
        "-blur",
        "1x65535",
    ]

    if negate:
        inversion_step = ["-negate"]
        out_file = f"{screenshot_path}.negative.png"
    else:
        inversion_step = []
        out_file = f"{screenshot_path}.png"

    ret = subprocess.run(
        [
            "magick",
            "convert",
            *leading_steps,
            *inversion_step,
            *trailing_steps,
            screenshot_path,
            out_file,
        ],
        capture_output=True,
    )

    if ret.returncode == 0:
        return out_file

    raise MachineError(
        f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}"
    )
|
||||||
Reference in New Issue
Block a user