nv-ingest-api
nim
cached
logger = logging.getLogger(__name__)
module-attribute
CachedModelInterface
Bases: ModelInterface
An interface for handling inference with a Cached model, supporting both gRPC and HTTP protocols, including batched input.
_extract_content_from_nim_response(json_response)
Extract content from the JSON response of a NIM (HTTP) API request.
Parameters
json_response : dict of str -> Any The JSON response from the NIM API.
Returns
Any The extracted content from the response.
Raises
RuntimeError If the response format is unexpected (missing 'data' or empty).
format_input(data, protocol, max_batch_size, **kwargs)
Format input data for the specified protocol ("grpc" or "http"), handling batched images.
Parameters
data : dict of str -> Any
    The input data dictionary, expected to contain "image_arrays" (a list of np.ndarray).
protocol : str
    The protocol to use, "grpc" or "http".
max_batch_size : int
    The maximum number of images per batch.
Returns
Any A list of formatted input batches. For gRPC, each batch is a NumPy array of shape (B, H, W, C) where B <= max_batch_size. For HTTP, each batch is a JSON-serializable dict containing base64‑encoded images.
Raises
KeyError
    If "image_arrays" is missing in the data dictionary.
ValueError
    If the protocol is invalid, or if images have differing shapes for gRPC.
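The gRPC batching behavior described above can be sketched as follows. This is a minimal illustration only; `batch_grpc_inputs` is a hypothetical helper name, not part of the API:

```python
import numpy as np

def batch_grpc_inputs(image_arrays, max_batch_size):
    """Stack same-shaped images into (B, H, W, C) batches with B <= max_batch_size."""
    if not image_arrays:
        return []
    shape = image_arrays[0].shape
    if any(img.shape != shape for img in image_arrays):
        # gRPC batching requires a single stacked tensor, so shapes must match.
        raise ValueError("All images must share the same shape for gRPC batching")
    return [
        np.stack(image_arrays[i:i + max_batch_size])
        for i in range(0, len(image_arrays), max_batch_size)
    ]
```

For example, five images with `max_batch_size=2` yield three batches of sizes 2, 2, and 1.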
name()
Get the name of the model interface.
Returns
str The name of the model interface ("Cached").
parse_output(response, protocol, data=None, **kwargs)
Parse the output from the Cached model's inference response.
Parameters
response : Any
    The raw response from the model inference.
protocol : str
    The protocol used ("grpc" or "http").
data : dict of str -> Any, optional
    Additional input data (unused here, but available for consistency).
**kwargs : Any
    Additional keyword arguments for future compatibility.
Returns
Any The parsed output data (e.g., list of strings), depending on the protocol.
Raises
ValueError
    If the protocol is invalid.
RuntimeError
    If the HTTP response is not as expected (missing 'data' key).
prepare_data_for_inference(data)
Decode base64-encoded images into NumPy arrays, storing them in data["image_arrays"].
Parameters
data : dict of str -> Any
    The input data containing either:
    - "base64_image": a single base64-encoded image, or
    - "base64_images": a list of base64-encoded images.
Returns
dict of str -> Any The updated data dictionary with decoded image arrays stored in "image_arrays", where each array has shape (H, W, C).
Raises
KeyError
    If neither 'base64_image' nor 'base64_images' is provided.
ValueError
    If 'base64_images' is provided but is not a list.
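A minimal sketch of this decoding step, assuming Pillow is used for image decoding (the helper names here are illustrative, not taken from the source):

```python
import base64
import io

import numpy as np
from PIL import Image

def decode_base64_image(b64: str) -> np.ndarray:
    """Decode a base64-encoded image into an (H, W, C) uint8 array."""
    img = Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
    return np.asarray(img)

def prepare_data_for_inference(data: dict) -> dict:
    """Populate data["image_arrays"] from base64 input, per the documented contract."""
    if "base64_images" in data:
        if not isinstance(data["base64_images"], list):
            raise ValueError("'base64_images' must be a list")
        images = data["base64_images"]
    elif "base64_image" in data:
        images = [data["base64_image"]]
    else:
        raise KeyError("Expected 'base64_image' or 'base64_images'")
    data["image_arrays"] = [decode_base64_image(b) for b in images]
    return data
```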
process_inference_results(output, protocol, **kwargs)
Process inference results for the Cached model.
Parameters
output : Any
    The raw output from the model.
protocol : str
    The inference protocol used ("grpc" or "http").
**kwargs : Any
    Additional parameters for post-processing (not used here).
Returns
Any The processed inference results, which here is simply returned as-is.
decorators
global_cache = manager.dict()
module-attribute
lock = Lock()
module-attribute
logger = logging.getLogger(__name__)
module-attribute
manager = Manager()
module-attribute
multiprocessing_cache(max_calls)
A decorator that creates a global cache shared between multiple processes.
The cache is invalidated after max_calls number of accesses.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `max_calls` | int | The number of calls after which the cache is cleared. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| function | | The decorated function with global cache and invalidation logic. |
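A minimal sketch of how such a decorator can be built from the module attributes listed above (`manager`, `global_cache`, `lock`). The cache-key scheme and the `"_calls"` counter key are assumptions for illustration:

```python
import functools
from multiprocessing import Lock, Manager

manager = Manager()
global_cache = manager.dict()  # shared across processes via a manager proxy
lock = Lock()

def multiprocessing_cache(max_calls):
    """Cache results in a process-shared dict; clear it after max_calls accesses."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = (func.__name__, args, tuple(sorted(kwargs.items())))
            with lock:
                global_cache["_calls"] = global_cache.get("_calls", 0) + 1
                if global_cache["_calls"] > max_calls:
                    global_cache.clear()
                    global_cache["_calls"] = 1
                if key in global_cache:
                    return global_cache[key]
            result = func(*args, **kwargs)
            with lock:
                global_cache[key] = result
            return result
        return wrapper
    return decorator
```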
deplot
logger = logging.getLogger(__name__)
module-attribute
DeplotModelInterface
Bases: ModelInterface
An interface for handling inference with a Deplot model, supporting both gRPC and HTTP protocols, now updated to handle multiple base64 images ('base64_images').
_extract_content_from_deplot_response(json_response)
staticmethod
Extract content from the JSON response of a Deplot HTTP API request. The original code expected a single choice with a single textual content.
_prepare_deplot_payload(base64_list, max_tokens=500, temperature=0.5, top_p=0.9)
staticmethod
Prepare an HTTP payload for Deplot that includes one message per image, matching the original single-image style:
messages = [
    {
        "role": "user",
        "content": "Generate ... <img src='data:image/png;base64,...' />"
    },
    {
        "role": "user",
        "content": "Generate ... <img src='data:image/png;base64,...' />"
    },
    ...
]
If your backend expects multiple messages in a single request, this keeps the same structure as the single-image code repeated N times.
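A sketch of a payload builder consistent with the message structure above. The exact prompt text and top-level payload keys are assumptions, not taken from the source:

```python
def prepare_deplot_payload(base64_list, max_tokens=500, temperature=0.5, top_p=0.9):
    """Build one user message per base64 image, mirroring the single-image style."""
    messages = [
        {
            "role": "user",
            # Prompt wording is illustrative; the real prompt may differ.
            "content": f"Generate ... <img src='data:image/png;base64,{b64}' />",
        }
        for b64 in base64_list
    ]
    return {
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
```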
format_input(data, protocol, max_batch_size, **kwargs)
Format input data for the specified protocol (gRPC or HTTP). For HTTP, we now construct multiple messages—one per image batch—in the same style as the original single-image code.
Parameters
data : dict of str -> Any
    The input data dictionary, expected to contain "image_arrays" (a list of np.ndarray).
protocol : str
    The protocol to use, "grpc" or "http".
max_batch_size : int
    The maximum number of images per batch.
kwargs : dict
    Additional parameters to pass to the payload preparation (for HTTP).
Returns
Any For gRPC: A list of NumPy arrays, each of shape (B, H, W, C) with B <= max_batch_size. For HTTP: A list of JSON-serializable payload dicts, each containing up to max_batch_size images.
Raises
KeyError
    If "image_arrays" is missing in the data dictionary.
ValueError
    If the protocol is invalid, or if images have differing shapes for gRPC.
name()
Get the name of the model interface.
Returns
str The name of the model interface ("Deplot").
parse_output(response, protocol, data=None, **kwargs)
Parse the model's inference response.
prepare_data_for_inference(data)
Prepare input data by decoding one or more base64-encoded images into NumPy arrays.
Parameters
data : dict The input data containing either 'base64_image' (single image) or 'base64_images' (multiple images).
Returns
dict The updated data dictionary with 'image_arrays': a list of decoded NumPy arrays.
process_inference_results(output, protocol, **kwargs)
Process inference results for the Deplot model.
Parameters
output : Any
    The raw output from the model.
protocol : str
    The protocol used for inference (gRPC or HTTP).
Returns
Any The processed inference results.
doughnut
ACCEPTED_CLASSES = ACCEPTED_TEXT_CLASSES | ACCEPTED_TABLE_CLASSES | ACCEPTED_IMAGE_CLASSES
module-attribute
ACCEPTED_IMAGE_CLASSES = set(['Picture'])
module-attribute
ACCEPTED_TABLE_CLASSES = set(['Table'])
module-attribute
ACCEPTED_TEXT_CLASSES = set(['Text', 'Title', 'Section-header', 'List-item', 'TOC', 'Bibliography', 'Formula', 'Page-header', 'Page-footer', 'Caption', 'Footnote', 'Floating-text'])
module-attribute
_re_extract_class_bbox = re.compile('<x_(\\d+)><y_(\\d+)>((?:|.(?:(?<!<x_\\d)(?<!<y_\\d)(?<!<class_[A-Za-z0-9]).)*))<x_(\\d+)><y_(\\d+)><class_([A-Za-z0-9\\-]+)>', re.MULTILINE | re.DOTALL)
module-attribute
logger = logging.getLogger(__name__)
module-attribute
_fix_dots(m)
extract_classes_bboxes(text)
postprocess_text(txt, cls)
reverse_transform_bbox(bbox, bbox_offset, original_width, original_height)
strip_markdown_formatting(text)
helpers
DEPLOT_MAX_TOKENS = 128
module-attribute
DEPLOT_TEMPERATURE = 1.0
module-attribute
DEPLOT_TOP_P = 1.0
module-attribute
logger = logging.getLogger(__name__)
module-attribute
ModelInterface
Base class for defining a model interface that supports preparing input data, formatting it for inference, parsing output, and processing inference results.
format_input(data, protocol, max_batch_size)
Format the input data for the specified protocol.
Parameters
data : dict
    The input data to format.
protocol : str
    The protocol to format the data for.
max_batch_size : int
    The maximum number of items per batch.
name()
Get the name of the model interface.
Returns
str The name of the model interface.
parse_output(response, protocol, data=None, **kwargs)
Parse the output data from the model's inference response.
Parameters
response : Any
    The response from the model inference.
protocol : str
    The protocol used ("grpc" or "http").
data : dict, optional
    Additional input data passed to the function.
prepare_data_for_inference(data)
Prepare input data for inference by processing or transforming it as required.
Parameters
data : dict The input data to prepare.
process_inference_results(output_array, protocol, **kwargs)
Process the inference results from the model.
Parameters
output_array : Any
    The raw output from the model.
kwargs : dict
    Additional parameters for processing.
NimClient
A client for interfacing with a model inference server using gRPC or HTTP protocols.
_lock = threading.Lock()
instance-attribute
_max_batch_sizes = {}
instance-attribute
auth_token = auth_token
instance-attribute
client = None
instance-attribute
endpoint_url = generate_url(self._http_endpoint)
instance-attribute
headers = {'accept': 'application/json', 'content-type': 'application/json'}
instance-attribute
max_retries = max_retries
instance-attribute
model_interface = model_interface
instance-attribute
protocol = protocol.lower()
instance-attribute
timeout = timeout
instance-attribute
__init__(model_interface, protocol, endpoints, auth_token=None, timeout=120.0, max_retries=5)
Initialize the NimClient with the specified model interface, protocol, and server endpoints.
Parameters
model_interface : ModelInterface
    The model interface implementation to use.
protocol : str
    The protocol to use ("grpc" or "http").
endpoints : tuple
    A tuple containing the gRPC and HTTP endpoints.
auth_token : str, optional
    Authorization token for HTTP requests (default: None).
timeout : float, optional
    Timeout for HTTP requests in seconds (default: 120.0).
max_retries : int, optional
    Maximum number of retries for failed HTTP requests (default: 5).
Raises
ValueError If an invalid protocol is specified or if required endpoints are missing.
_fetch_max_batch_size(model_name, model_version='')
Fetch the maximum batch size from the Triton model configuration in a thread-safe manner.
_grpc_infer(formatted_input, model_name)
Perform inference using the gRPC protocol.
Parameters
formatted_input : np.ndarray
    The input data formatted as a numpy array.
model_name : str
    The name of the model to use for inference.
Returns
np.ndarray The output of the model as a numpy array.
_http_infer(formatted_input)
Perform inference using the HTTP protocol, retrying for timeouts or 5xx errors up to 5 times.
Parameters
formatted_input : dict The input data formatted as a dictionary.
Returns
dict The output of the model as a dictionary.
Raises
TimeoutError
    If the HTTP request times out repeatedly, up to the max retries.
requests.RequestException
    For other HTTP-related errors that persist after max retries.
_process_batch(batch_input, *, prepared_data, model_name, **kwargs)
Process a single batch input for inference.
Parameters
batch_input : Any
    The batch input data to process.
prepared_data : Any
    The prepared data used for inference.
model_name : str
    The model name to use for inference.
kwargs : dict
    Additional parameters for inference.
Returns
Any The parsed output from the inference request.
close()
infer(data, model_name, **kwargs)
Perform inference using the specified model and input data.
Parameters
data : dict
    The input data for inference.
model_name : str
    The name of the model to use for inference.
kwargs : dict
    Additional parameters for inference. Optionally supports "max_pool_workers" to set the number of worker threads in the thread pool.
Returns
Any The processed inference results.
Raises
ValueError If an invalid protocol is specified.
try_set_max_batch_size(model_name, model_version='')
Attempt to set the max batch size for the model if it is not already set, ensuring thread safety.
create_inference_client(endpoints, model_interface, auth_token=None, infer_protocol=None)
Create a NimClient for interfacing with a model inference server.
Parameters
endpoints : tuple
    A tuple containing the gRPC and HTTP endpoints.
model_interface : ModelInterface
    The model interface implementation to use.
auth_token : str, optional
    Authorization token for HTTP requests (default: None).
infer_protocol : str, optional
    The protocol to use ("grpc" or "http"). If not specified, it is inferred from the endpoints.
Returns
NimClient The initialized NimClient.
Raises
ValueError If an invalid infer_protocol is specified.
generate_url(url)
Examines the user-defined URL for an http*:// pattern. If that pattern is detected, the URL is used as provided. Otherwise, the endpoint is assumed to be plain http://, which is prepended to the user-supplied endpoint.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `url` | str | Endpoint where the REST service is running | required |

Returns:

| Name | Type | Description |
|---|---|---|
| str | str | Fully validated URL |
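The scheme check described above can be sketched in a few lines (a minimal illustration of the documented behavior):

```python
import re

def generate_url(url: str) -> str:
    """Prepend 'http://' unless the URL already carries an http(s) scheme."""
    if re.match(r"^https?://", url):
        return url
    return "http://" + url
```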
get_version(http_endpoint, metadata_endpoint='/v1/metadata', version_field='version')
Get the version of the server from its metadata endpoint.
Parameters
http_endpoint : str
    The HTTP endpoint of the server.
metadata_endpoint : str, optional
    The metadata endpoint to query (default: "/v1/metadata").
version_field : str, optional
    The field containing the version in the response (default: "version").
Returns
str The version of the server, or an empty string if unavailable.
is_ready(http_endpoint, ready_endpoint)
Check if the server at the given endpoint is ready.
Parameters
http_endpoint : str
    The HTTP endpoint of the server.
ready_endpoint : str
    The specific ready-check endpoint.
Returns
bool True if the server is ready, False otherwise.
preprocess_image_for_paddle(array, paddle_version=None)
Preprocesses an input image to be suitable for use with PaddleOCR by resizing, normalizing, padding, and transposing it into the required format.
This function is intended for preprocessing images to be passed as input to PaddleOCR using GRPC. It is not necessary when using the HTTP endpoint.
Steps:
- Resizes the image while maintaining aspect ratio such that its largest dimension is scaled to 960 pixels.
- Normalizes the image using the `normalize_image` function.
- Pads the image to ensure both its height and width are multiples of 32, as required by PaddleOCR.
- Transposes the image from (height, width, channel) to (channel, height, width), the format expected by PaddleOCR.
Parameters:
array : np.ndarray The input image array of shape (height, width, channels). It should have pixel values in the range [0, 255].
Returns:
np.ndarray A preprocessed image with the shape (channels, height, width) and normalized pixel values. The image will be padded to have dimensions that are multiples of 32, with the padding color set to 0.
Notes:
- The image is resized so that its largest dimension becomes 960 pixels, maintaining the aspect ratio.
- After normalization, the image is padded to the nearest multiple of 32 in both dimensions, which is a requirement for PaddleOCR.
- The normalized pixel values are scaled between 0 and 1 before padding and transposing the image.
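The steps above can be sketched as follows. This is a simplified illustration: the real code uses a proper image-resize routine (e.g. cv2), whereas this sketch uses index-sampling nearest-neighbour resizing, and the helper name is hypothetical:

```python
import math

import numpy as np

def preprocess_for_paddle(array: np.ndarray, target: int = 960) -> np.ndarray:
    """Resize (largest side -> target), scale to [0, 1], pad H/W to multiples
    of 32 with zeros, and transpose HWC -> CHW."""
    h, w = array.shape[:2]
    scale = target / max(h, w)
    new_h, new_w = int(h * scale), int(w * scale)
    # Nearest-neighbour resize via index sampling (stand-in for cv2.resize).
    rows = (np.arange(new_h) / scale).astype(int).clip(0, h - 1)
    cols = (np.arange(new_w) / scale).astype(int).clip(0, w - 1)
    resized = array[rows][:, cols].astype(np.float32) / 255.0
    # Pad both dimensions up to the next multiple of 32.
    pad_h = math.ceil(new_h / 32) * 32
    pad_w = math.ceil(new_w / 32) * 32
    padded = np.zeros((pad_h, pad_w, array.shape[2]), dtype=np.float32)
    padded[:new_h, :new_w] = resized
    return padded.transpose(2, 0, 1)
```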
remove_url_endpoints(url)
Some configurations provide the full endpoint in the URL, e.g. http://deplot:8000/v1/chat/completions. To hit the health endpoint we need just the hostname:port combination, to which the health/ready endpoint can be appended, so we attempt to parse that information here.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `url` | str | Incoming URL | required |

Returns:

| Name | Type | Description |
|---|---|---|
| str | str | URL with just the hostname:port portion remaining |
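A minimal sketch of this path-stripping behavior, assuming the incoming URL already carries a scheme:

```python
from urllib.parse import urlparse

def remove_url_endpoints(url: str) -> str:
    """Drop any path (e.g. /v1/chat/completions), keeping scheme://host:port."""
    parsed = urlparse(url)
    return f"{parsed.scheme}://{parsed.netloc}"
```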
paddle
logger = logging.getLogger(__name__)
module-attribute
PaddleOCRModelInterface
Bases: ModelInterface
An interface for handling inference with a PaddleOCR model, supporting both gRPC and HTTP protocols.
paddle_version = paddle_version
instance-attribute
__init__(paddle_version=None)
Initialize the PaddleOCR model interface.
Parameters
paddle_version : str, optional The version of the PaddleOCR model (default is None).
_convert_paddle_response_to_psuedo_markdown(bounding_boxes, text_predictions, img_index=0, dims=None)
staticmethod
Convert bounding boxes & text to pseudo-markdown format. For multiple images,
the correct image dimensions (height, width) are retrieved from dims[img_index].
Parameters
bounding_boxes : list of Any
    A list (per line of text) of bounding boxes, each a list of (x, y) points.
text_predictions : list of str
    A list of text predictions, one for each bounding box.
img_index : int, optional
    The index of the image for which bounding boxes are being converted. Default is 0.
dims : list of (int, int), optional
    A list of (height, width) for each corresponding image.
Returns
str The pseudo-markdown representation of detected text lines and bounding boxes. Each cluster of text is placed on its own line, with text columns separated by '|'.
Notes
- If `dims` is None or `img_index` is out of range, bounding boxes will not be scaled properly.
_extract_content_from_paddle_grpc_response(response, table_content_format, dims)
Parse a gRPC response for one or more images. The response can have two possible shapes:

- (3,) for batch_size=1
- (3, n) for batch_size=n

In either case:

- response[0, i]: byte string containing bounding box data
- response[1, i]: byte string containing text prediction data
- response[2, i]: (optional) additional data/metadata (ignored here)
Parameters
response : np.ndarray
    The raw NumPy array from gRPC. Expected shape: (3,) or (3, n).
table_content_format : str
    The format of the output text content, e.g. 'simple' or 'pseudo_markdown'.
dims : list of (int, int), optional
    A list of (height, width) for each corresponding image, used for bounding box scaling.
Returns
list of (str, str) A list of (content, table_content_format) for each image.
Raises
ValueError
If the response is not a NumPy array or has an unexpected shape,
or if the table_content_format is unrecognized.
_extract_content_from_paddle_http_response(json_response, table_content_format, dims)
Extract content from the JSON response of a PaddleOCR HTTP API request.
Parameters
json_response : dict of str -> Any
    The JSON response returned by the PaddleOCR endpoint.
table_content_format : str or None
    The specified format for table content (e.g., 'simple' or 'pseudo_markdown').
dims : list of (int, int), optional
    A list of (height, width) for each corresponding image, used for bounding box scaling if not None.
Returns
list of (str, str) A list of (content, table_content_format) tuples, one for each image result.
Raises
RuntimeError
If the response format is missing or invalid.
ValueError
If the table_content_format is unrecognized.
_is_version_early_access_legacy_api()
Determine if the current PaddleOCR version is considered "early access" and thus uses the legacy API format.
Returns
bool True if the version is < 0.2.1-rc2; False otherwise.
_prepare_paddle_payload(base64_img)
DEPRECATED by batch logic in format_input. Kept here if you need single-image direct calls.
Parameters
base64_img : str A single base64-encoded image string.
Returns
dict of str -> Any The payload in either legacy or new format for PaddleOCR's HTTP endpoint.
format_input(data, protocol, max_batch_size, **kwargs)
Format input data for the specified protocol ("grpc" or "http"), supporting batched data.
Parameters
data : dict of str -> Any
    The input data dictionary, expected to contain "image_arrays" (list of np.ndarray).
protocol : str
    The inference protocol, either "grpc" or "http".
max_batch_size : int
    The maximum number of images per batch.
Returns
Any A list of formatted batches. For gRPC, each item is a batched NumPy array of shape (B, H, W, C) where B <= max_batch_size. For HTTP, each item is a JSON-serializable payload containing the base64 images in the format required by the PaddleOCR endpoint.
Raises
KeyError
If "image_arrays" is not found in data.
ValueError
If an invalid protocol is specified, or if the image shapes are inconsistent for gRPC batching.
name()
Get the name of the model interface.
Returns
str The name of the model interface, including the PaddleOCR version.
parse_output(response, protocol, data=None, **kwargs)
Parse the model's inference response for the given protocol. The parsing may handle batched outputs for multiple images.
Parameters
response : Any
The raw response from the PaddleOCR model.
protocol : str
The protocol used for inference, "grpc" or "http".
data : dict of str -> Any, optional
Additional data dictionary that may include "image_dims" for bounding box scaling.
**kwargs : Any
Additional keyword arguments, such as custom table_content_format.
Returns
Any The parsed output, typically a list of (content, table_content_format) tuples.
Raises
ValueError If an invalid protocol is specified.
prepare_data_for_inference(data)
Decode one or more base64-encoded images into NumPy arrays, storing them
alongside their dimensions in data.
Parameters
data : dict of str -> Any
    The input data containing either:
    - 'base64_image': a single base64-encoded image, or
    - 'base64_images': a list of base64-encoded images.
Returns
dict of str -> Any
    The updated data dictionary with the following keys added:
    - "image_arrays": list of decoded NumPy arrays of shape (H, W, C).
    - "image_dims": list of (height, width) tuples for each decoded image.
Raises
KeyError
If neither 'base64_image' nor 'base64_images' is found in data.
ValueError
If 'base64_images' is present but is not a list.
process_inference_results(output, **kwargs)
Process inference results for the PaddleOCR model.
Parameters
output : Any
    The raw output parsed from the PaddleOCR model.
**kwargs : Any
    Additional keyword arguments for customization.
Returns
Any The post-processed inference results. By default, this simply returns the output as the table content (or content list).
yolox
YOLOX_CONF_THRESHOLD = 0.01
module-attribute
YOLOX_FINAL_SCORE = 0.48
module-attribute
YOLOX_IMAGE_PREPROC_HEIGHT = 1024
module-attribute
YOLOX_IMAGE_PREPROC_WIDTH = 1024
module-attribute
YOLOX_IOU_THRESHOLD = 0.5
module-attribute
YOLOX_MAX_BATCH_SIZE = 8
module-attribute
YOLOX_MAX_HEIGHT = 1536
module-attribute
YOLOX_MAX_WIDTH = 1536
module-attribute
YOLOX_MIN_SCORE = 0.1
module-attribute
YOLOX_NIM_MAX_IMAGE_SIZE = 512000
module-attribute
YOLOX_NUM_CLASSES = 3
module-attribute
logger = logging.getLogger(__name__)
module-attribute
YoloxPageElementsModelInterface
Bases: ModelInterface
An interface for handling inference with a Yolox object detection model, supporting both gRPC and HTTP protocols.
format_input(data, protocol, max_batch_size, **kwargs)
Format input data for the specified protocol, returning a list of batches each up to 'max_batch_size' in length.
Parameters
data : dict
    The input data to format.
protocol : str
    The protocol to use ("grpc" or "http").
max_batch_size : int
    The maximum batch size to respect.
Returns
List[Any] A list of batches, each formatted according to the protocol.
name()
Returns the name of the Yolox model interface.
Returns
str The name of the model interface.
parse_output(response, protocol, data=None, **kwargs)
Parse the output from the model's inference response.
Parameters
response : Any
    The response from the model inference.
protocol : str
    The protocol used ("grpc" or "http").
data : dict, optional
    Additional input data passed to the function.
Returns
Any The parsed output data.
Raises
ValueError If an invalid protocol is specified or the response format is unexpected.
prepare_data_for_inference(data)
Prepare input data for inference by resizing images and storing their original shapes.
Parameters
data : dict The input data containing a list of images.
Returns
dict The updated data dictionary with resized images and original image shapes.
process_inference_results(output, protocol, **kwargs)
Process the results of the Yolox model inference and return the final annotations.
Parameters
output : np.ndarray
    The raw output from the Yolox model.
protocol : str
    The protocol used for inference ("grpc" or "http").
kwargs : dict
    Additional parameters for processing, including thresholds and number of classes.
Returns
list[dict] A list of annotation dictionaries for each image in the batch.
bb_iou_array(boxes, new_box)
chunkify(lst, chunk_size)
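`chunkify` is a one-liner in spirit; a minimal sketch consistent with its name and signature:

```python
def chunkify(lst, chunk_size):
    """Split lst into consecutive chunks of at most chunk_size items."""
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
```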
expand_boxes(boxes, r_x=1, r_y=1)
expand_chart_bboxes(annotation_dict, labels=None)
Expand bounding boxes of charts and titles based on the bounding boxes of the other class.

Args: annotation_dict: output of postprocess_results, a dictionary with keys "table", "figure", "title"

Returns:

| Name | Type | Description |
|---|---|---|
| `annotation_dict` | | Same as input, with expanded bboxes for charts. |
expand_table_bboxes(annotation_dict, labels=None)
Additional preprocessing for tables: extend the upper bounds to capture titles if any.

Args: annotation_dict: output of postprocess_results, a dictionary with keys "table", "figure", "title"

Returns:

| Name | Type | Description |
|---|---|---|
| `annotation_dict` | | Same as input, with expanded bboxes for tables. |
find_matching_box_fast(boxes_list, new_box, match_iou)
Reimplementation of find_matching_box with numpy instead of loops. Gives significant speed up for larger arrays (~100x). This was previously the bottleneck since the function is called for every entry in the array.
get_biggest_box(boxes, conf_type='avg')
Merges boxes by using the biggest box.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `boxes` | np array [n x 8] | Boxes to merge. | required |
| `conf_type` | str | Confidence merging type. Defaults to "avg". | 'avg' |

Returns:

| Type | Description |
|---|---|
| np array [8] | Merged box. |
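A sketch of biggest-box fusion. The 8-column layout is an assumption for illustration: here column 1 is taken as the confidence and the last four columns as (x0, y0, x1, y1); the real column ordering may differ:

```python
import numpy as np

def get_biggest_box(boxes: np.ndarray, conf_type: str = "avg") -> np.ndarray:
    """Merge an [n x 8] stack of boxes into one enclosing box.

    Assumes column 1 holds confidence and columns 4:8 hold (x0, y0, x1, y1).
    """
    merged = boxes[0].copy()
    merged[4] = boxes[:, 4].min()  # enclosing x0
    merged[5] = boxes[:, 5].min()  # enclosing y0
    merged[6] = boxes[:, 6].max()  # enclosing x1
    merged[7] = boxes[:, 7].max()  # enclosing y1
    merged[1] = boxes[:, 1].mean() if conf_type == "avg" else boxes[:, 1].max()
    return merged
```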
get_weighted_box(boxes, conf_type='avg')
Merges boxes by using the weighted fusion.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `boxes` | np array [n x 8] | Boxes to merge. | required |
| `conf_type` | str | Confidence merging type. Defaults to "avg". | 'avg' |

Returns:

| Type | Description |
|---|---|
| np array [8] | Merged box. |
match_with_title(chart_bbox, title_bboxes, iou_th=0.01)
merge_boxes(b1, b2)
merge_labels(labels, confs)
Custom function for merging labels. If all labels are the same, return the unique value. Else, return the label of the most confident non-title (class 2) box.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `labels` | np array [n] | Labels. | required |
| `confs` | np array [n] | Confidences. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| int | | Label. |
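The label-merging rule above can be sketched as follows (an illustrative reimplementation, assuming class 2 is the title class as stated):

```python
import numpy as np

def merge_labels(labels: np.ndarray, confs: np.ndarray) -> int:
    """Return the unique label, or the label of the most confident
    non-title (class != 2) box when labels disagree."""
    if len(np.unique(labels)) == 1:
        return int(labels[0])
    mask = labels != 2  # ignore title boxes when breaking ties
    if not mask.any():
        mask = np.ones_like(labels, dtype=bool)
    idx = int(np.argmax(np.where(mask, confs, -np.inf)))
    return int(labels[idx])
```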
postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False)
postprocess_results(results, original_image_shapes, min_score=0.0)
For each item (== image) in results, computes annotations of the form:

{"table": [[0.0107, 0.0859, 0.7537, 0.1219, 0.9861], ...], "figure": [...], "title": [...]}

where each list of 5 floats represents a bounding box in the format [x1, y1, x2, y2, confidence].
Keep only bboxes with high enough confidence.
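The confidence filter can be sketched as follows, assuming the annotation-dict shape shown above (`filter_annotations` is a hypothetical helper name):

```python
def filter_annotations(annotations: dict, min_score: float = 0.0) -> dict:
    """Keep only [x1, y1, x2, y2, confidence] boxes with confidence >= min_score."""
    return {
        label: [box for box in boxes if box[4] >= min_score]
        for label, boxes in annotations.items()
    }
```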
prefilter_boxes(boxes, scores, labels, weights, thr, class_agnostic=False)
Reformats and filters boxes. Output is a dict of boxes to merge separately.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `boxes` | list[np array [n x 4]] | List of boxes. One list per model. | required |
| `scores` | list[np array [n]] | List of confidences. | required |
| `labels` | list[np array [n]] | List of labels. | required |
| `weights` | list | Model weights. | required |
| `thr` | float | Confidence threshold. | required |
| `class_agnostic` | bool | If True, merge boxes from different classes. Defaults to False. | False |

Returns:

| Type | Description |
|---|---|
| dict[np array [? x 8]] | Filtered boxes. |
resize_image(image, target_img_size)
weighted_boxes_fusion(boxes_list, scores_list, labels_list, iou_thr=0.5, skip_box_thr=0.0, conf_type='avg', merge_type='weighted', class_agnostic=False)
Custom wbf implementation that supports a class_agnostic mode and a biggest box fusion. Boxes are expected to be in normalized (x0, y0, x1, y1) format.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `boxes_list` | list[np array [n x 4]] | List of boxes. One list per model. | required |
| `scores_list` | list[np array [n]] | List of confidences. | required |
| `labels_list` | list[np array [n]] | List of labels. | required |
| `iou_thr` | float | IoU threshold for matching. Defaults to 0.5. | 0.5 |
| `skip_box_thr` | float | Exclude boxes with score < skip_box_thr. Defaults to 0.0. | 0.0 |
| `conf_type` | str | Confidence merging type. Defaults to "avg". | 'avg' |
| `merge_type` | str | Merge type, "weighted" or "biggest". Defaults to "weighted". | 'weighted' |
| `class_agnostic` | bool | If True, merge boxes from different classes. Defaults to False. | False |

Returns:

| Type | Description |
|---|---|
| np array [N x 4] | Merged boxes. |
| np array [N] | Merged confidences. |
| np array [N] | Merged labels. |