From 3f116ad6772e104df575b6acd23486a96425082e Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Wed, 15 May 2024 14:17:23 +0200 Subject: [PATCH] default to fast base (#1588) --- demo/backend/pytorch.py | 6 +++--- demo/backend/tensorflow.py | 6 +++--- doctr/models/detection/zoo.py | 2 +- doctr/models/zoo.py | 4 ++-- scripts/analyze.py | 2 +- scripts/detect_text.py | 2 +- tests/pytorch/test_models_zoo_pt.py | 12 ++++++------ tests/tensorflow/test_models_zoo_tf.py | 8 ++++---- 8 files changed, 21 insertions(+), 21 deletions(-) diff --git a/demo/backend/pytorch.py b/demo/backend/pytorch.py index 21da01d22b..9ce8532b2f 100644 --- a/demo/backend/pytorch.py +++ b/demo/backend/pytorch.py @@ -10,15 +10,15 @@ from doctr.models.predictor import OCRPredictor DET_ARCHS = [ + "fast_base", + "fast_small", + "fast_tiny", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50", - "fast_tiny", - "fast_small", - "fast_base", ] RECO_ARCHS = [ "crnn_vgg16_bn", diff --git a/demo/backend/tensorflow.py b/demo/backend/tensorflow.py index c38a193f19..980ae628d8 100644 --- a/demo/backend/tensorflow.py +++ b/demo/backend/tensorflow.py @@ -10,14 +10,14 @@ from doctr.models.predictor import OCRPredictor DET_ARCHS = [ + "fast_base", + "fast_small", + "fast_tiny", "db_resnet50", "db_mobilenet_v3_large", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50", - "fast_tiny", - "fast_small", - "fast_base", ] RECO_ARCHS = [ "crnn_vgg16_bn", diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index 45cbc1adc5..3cab59e381 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -75,7 +75,7 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, def detection_predictor( - arch: Any = "db_resnet50", + arch: Any = "fast_base", pretrained: bool = False, assume_straight_pages: bool = True, **kwargs: Any, diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py index a351589037..eff5fe14c4 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -61,7 +61,7 @@ def _predictor( def ocr_predictor( - det_arch: Any = "db_resnet50", + det_arch: Any = "fast_base", reco_arch: Any = "crnn_vgg16_bn", pretrained: bool = False, pretrained_backbone: bool = True, @@ -175,7 +175,7 @@ def _kie_predictor( def kie_predictor( - det_arch: Any = "db_resnet50", + det_arch: Any = "fast_base", reco_arch: Any = "crnn_vgg16_bn", pretrained: bool = False, pretrained_backbone: bool = True, diff --git a/scripts/analyze.py b/scripts/analyze.py index 2f64175a8e..94415267a2 100644 --- a/scripts/analyze.py +++ b/scripts/analyze.py @@ -43,7 +43,7 @@ def parse_args(): ) parser.add_argument("path", type=str, help="Path to the input document (PDF or image)") - parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis") + parser.add_argument("--detection", type=str, default="fast_base", help="Text detection model to use for analysis") parser.add_argument( "--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis" ) diff --git a/scripts/detect_text.py b/scripts/detect_text.py index 573080c59b..f65b6685df 100644 --- a/scripts/detect_text.py +++ b/scripts/detect_text.py @@ -85,7 +85,7 @@ def parse_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("path", type=str, help="Path to process: PDF, image, directory") - parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis") + parser.add_argument("--detection", type=str, default="fast_base", help="Text detection model to use for analysis") parser.add_argument("--bin-thresh", type=float, default=0.3, help="Binarization threshold for the detection model.") parser.add_argument("--box-thresh", type=float, default=0.1, help="Threshold for the detection boxes.") parser.add_argument( diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index 5bcd10ee62..fe67321668 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -83,7 +83,7 @@ def test_trained_ocr_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip) det_predictor = detection_predictor( - "db_resnet50", + "fast_base", pretrained=True, batch_size=2, assume_straight_pages=True, @@ -111,7 +111,7 @@ def test_trained_ocr_predictor(mock_payslip): assert np.allclose(np.array(out.pages[0].blocks[1].lines[0].words[-1].geometry), geometry_revised, rtol=0.05) det_predictor = detection_predictor( - "db_resnet50", + "fast_base", pretrained=True, batch_size=2, assume_straight_pages=True, @@ -196,7 +196,7 @@ def test_trained_kie_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip) det_predictor = detection_predictor( - "db_resnet50", + "fast_base", pretrained=True, batch_size=2, assume_straight_pages=True, @@ -222,12 +222,12 @@ def test_trained_kie_predictor(mock_payslip): geometry_mr = np.array([[0.1083984375, 0.0634765625], [0.1494140625, 0.0859375]]) assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][0].geometry), geometry_mr, rtol=0.05) - assert out.pages[0].predictions[CLASS_NAME][4].value == "revised" + assert out.pages[0].predictions[CLASS_NAME][3].value == "revised" geometry_revised = np.array([[0.7548828125, 0.126953125], [0.8388671875, 0.1484375]]) - assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][4].geometry), geometry_revised, rtol=0.05) + assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][3].geometry), geometry_revised, rtol=0.05) det_predictor = detection_predictor( - "db_resnet50", + "fast_base", pretrained=True, batch_size=2, assume_straight_pages=True, diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py index 906d6d0f5d..f0c73a9bef 100644 --- a/tests/tensorflow/test_models_zoo_tf.py +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -82,7 +82,7 @@ def test_trained_ocr_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip) det_predictor = detection_predictor( - "db_resnet50", + "fast_base", pretrained=True, batch_size=2, assume_straight_pages=True, @@ -112,7 +112,7 @@ def test_trained_ocr_predictor(mock_payslip): assert np.allclose(np.array(out.pages[0].blocks[1].lines[0].words[-1].geometry), geometry_revised, rtol=0.05) det_predictor = detection_predictor( - "db_resnet50", + "fast_base", pretrained=True, batch_size=2, assume_straight_pages=True, @@ -194,7 +194,7 @@ def test_trained_kie_predictor(mock_payslip): doc = DocumentFile.from_images(mock_payslip) det_predictor = detection_predictor( - "db_resnet50", + "fast_base", pretrained=True, batch_size=2, assume_straight_pages=True, @@ -225,7 +225,7 @@ def test_trained_kie_predictor(mock_payslip): assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][3].geometry), geometry_revised, rtol=0.05) det_predictor = detection_predictor( - "db_resnet50", + "fast_base", pretrained=True, batch_size=2, assume_straight_pages=True,