predict
During evaluation or stress testing, Efemarai analyzes the predictions made by the model to find potential problems. The `predict` section specifies how these predictions are obtained from the model.
Example
The following definition
```yaml
predict:
  entrypoint: inference:predict
  inputs:
    - name: model
      value: ${model.runtime.load.output.model}
    - name: datapoints
      value: ${datapoints}
    # device is passed to the function by default; do not list it here
  output:
    name: predictions
```
describes a Python function `predict` in `inference.py` (expected to be in the model repository), which could, for example, look like this:
```python
import efemarai as ef
import torch


def output_to_sdk(image, output):
    """Convert the raw outputs for one image into Efemarai SDK objects."""
    sdk_outputs = []
    for k, v in output.items():
        for i, field in enumerate(v):
            if k == "boxes":
                label = ef.AnnotationClass(id=int(output["classes"][i]))
                sdk_outputs.append(
                    ef.BoundingBox(
                        xyxy=field,
                        label=label,
                        confidence=output["scores"][i],
                        instance_id=i,
                        key_name=k,
                        ref_field=image,
                    )
                )
    return sdk_outputs


def predict(model, datapoints, device):
    """Runs the model on a batch of datapoints.

    Args:
        model (CarDetector): Loaded car detector model.
        datapoints (List[ef.Datapoint]): Batch of datapoints holding the model inputs.
        device (str): Device name, passed in by default.

    Returns:
        A list with one entry per datapoint, each a list of ef.BaseFields.
    """
    # Extract the image field from each datapoint
    # (assumes exactly one image input per datapoint)
    image_fields = [dp.get_inputs_with_type(ef.fields.Image)[0] for dp in datapoints]

    # Pre-process the batch: NHWC uint8 -> NCHW float in [0, 1]
    # (assumes each image field exposes its pixel array via `.data`)
    images = torch.stack([torch.tensor(image.data) for image in image_fields])
    images = images.to(device).permute(0, 3, 1, 2).float() / 255.0

    # Perform inference without tracking gradients
    with torch.no_grad():
        output = model(images)

    # Split each image's detections (rows of [x1, y1, x2, y2, score, class])
    outputs = [
        {
            "boxes": detections[:, :4].tolist(),
            "scores": detections[:, 4].tolist(),
            "classes": detections[:, 5].tolist(),
        }
        for detections in output
    ]

    # Convert outputs to ef.BaseFields that reference the original image fields
    return [output_to_sdk(image_fields[i], y) for i, y in enumerate(outputs)]
```
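The following dependency-free sketch makes the column layout explicit by mirroring the "Get required outputs" step with plain Python lists instead of tensors. It assumes that each detection row is `[x1, y1, x2, y2, score, class]`, which is the layout the slicing in the example implies. `split_detections` is a hypothetical helper and is not part of the Efemarai SDK.

```python
# Hypothetical helper mirroring the "Get required outputs" step of predict()
# with plain lists. Assumption: each detection row is
# [x1, y1, x2, y2, score, class_id], as the slicing in predict() implies.


def split_detections(detections):
    """Split one image's N x 6 detections into boxes, scores and classes."""
    return {
        "boxes": [row[:4] for row in detections],
        "scores": [row[4] for row in detections],
        "classes": [row[5] for row in detections],
    }


batch = [
    [[10.0, 20.0, 50.0, 80.0, 0.9, 2.0]],  # image 0: one detection
    [],                                     # image 1: no detections
]
outputs = [split_detections(d) for d in batch]
print(outputs[0]["boxes"], outputs[0]["scores"], outputs[0]["classes"])
# prints [[10.0, 20.0, 50.0, 80.0]] [0.9] [2.0]
print(outputs[1])
# prints {'boxes': [], 'scores': [], 'classes': []}
```

An image with no detections still produces an entry, so the result always has one dictionary per image in the batch.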
Here, `model` is the output of the `load` function, and `datapoints` is a batch of datapoints automatically aggregated by Efemarai.
See Variables for more details.
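Efemarai performs the batching itself. Purely as an illustration, a runner could aggregate datapoints and call a `predict`-style function roughly as shown below, while checking that it returns one result per datapoint. `run_in_batches`, `batch_size` and `toy_predict` are hypothetical names, and this is not Efemarai's implementation.

```python
from typing import Any, Callable, List, Sequence

# Hypothetical sketch of batch aggregation around a predict-style function.
# Not Efemarai's implementation; all names here are illustrative.


def run_in_batches(
    predict_fn: Callable[[Any, List[Any], str], List[Any]],
    model: Any,
    datapoints: Sequence[Any],
    batch_size: int,
    device: str = "cpu",
) -> List[Any]:
    """Call predict_fn on consecutive batches and concatenate the results."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    results: List[Any] = []
    for start in range(0, len(datapoints), batch_size):
        batch = list(datapoints[start:start + batch_size])
        predictions = predict_fn(model, batch, device)
        # Expect exactly one entry per datapoint in the batch
        if len(predictions) != len(batch):
            raise ValueError(
                f"expected {len(batch)} predictions, got {len(predictions)}"
            )
        results.extend(predictions)
    return results


# Usage with a toy "model" that multiplies each datapoint.
def toy_predict(model, datapoints, device):
    return [[model * dp] for dp in datapoints]


print(run_in_batches(toy_predict, 2, [1, 2, 3, 4, 5], batch_size=2))
# prints [[2], [4], [6], [8], [10]]
```

The last batch may be smaller than `batch_size`. This is why the length check compares against the actual batch rather than the configured size.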
The `predict` function should return a list with one entry per datapoint, where each entry is a list of `ef.BaseFields` (here, `ef.BoundingBox` objects).
Every prediction needs a class and a score, while boxes and masks are optional depending on the problem type.
Properties
entrypoint
: specifies where the user function is, in the format `user_module:user_function`. The module must be importable from the root of the model repository and must contain the specified function. Sub-modules are also supported using the standard dotted import syntax, e.g. `module.submodule:function`.
inputs
: specifies a list of name-value pairs that will be passed as input arguments to the user function. Names must be valid Python variable names. Values are parsed from the YAML file and passed directly to the user function, unless they refer to runtime variables, in which case they are first substituted with the correct values (see Variables). The runtime device is passed to the user function by default, so it must not be specified explicitly as an input (see runtime).
output
: specifies the name of the user function's output (`predictions` in the example above). For each datapoint, the function should return a list of `ef.BaseFields`.
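The `entrypoint` and `inputs` rules above can be illustrated with a small, hypothetical runner. The sketch assumes that a `${a.b.c}` reference is resolved by walking a nested dictionary of runtime variables. The Variables page remains the authority on the real behavior. `resolve_entrypoint`, `substitute` and `build_kwargs` are illustrative names, not Efemarai SDK functions.

```python
import importlib
import keyword
import re
from typing import Any, Callable, Dict, List

# Hypothetical runner sketch; none of these names belong to the Efemarai SDK.
# Assumption: "${a.b.c}" references are looked up in a nested dict of
# runtime variables.
_VAR_REF = re.compile(r"^\$\{([^}]+)\}$")


def resolve_entrypoint(spec: str) -> Callable[..., Any]:
    """Turn 'module.submodule:function' into the function object."""
    module_name, sep, func_name = spec.partition(":")
    if not sep or not module_name or not func_name:
        raise ValueError(f"expected 'module:function', got {spec!r}")
    module = importlib.import_module(module_name)  # dotted sub-modules work here
    return getattr(module, func_name)


def substitute(value: Any, variables: Dict[str, Any]) -> Any:
    """Replace a whole-string '${a.b.c}' reference; pass other values through."""
    match = _VAR_REF.match(value) if isinstance(value, str) else None
    if match is None:
        return value
    node: Any = variables
    for key in match.group(1).split("."):
        node = node[key]
    return node


def build_kwargs(inputs: List[Dict[str, Any]], variables: Dict[str, Any],
                 device: str) -> Dict[str, Any]:
    """Map the YAML name/value pairs to keyword arguments for the user function."""
    kwargs: Dict[str, Any] = {}
    for item in inputs:
        name = item["name"]
        if not name.isidentifier() or keyword.iskeyword(name):
            raise ValueError(f"{name!r} is not a valid Python variable name")
        if name == "device":
            raise ValueError("device is passed by default; do not list it as an input")
        kwargs[name] = substitute(item["value"], variables)
    kwargs["device"] = device  # the runtime device is always added
    return kwargs


# Usage with made-up values ("threshold" is a hypothetical extra input).
variables = {
    "model": {"runtime": {"load": {"output": {"model": "my-model"}}}},
    "datapoints": [1, 2],
}
inputs = [
    {"name": "model", "value": "${model.runtime.load.output.model}"},
    {"name": "datapoints", "value": "${datapoints}"},
    {"name": "threshold", "value": 0.5},
]
print(build_kwargs(inputs, variables, device="cpu"))
# prints {'model': 'my-model', 'datapoints': [1, 2], 'threshold': 0.5, 'device': 'cpu'}
print(resolve_entrypoint("json:dumps")([1, 2]))
# prints [1, 2]
```

Plain YAML values such as `0.5` pass through unchanged. Only strings that are entirely a `${...}` reference are substituted, and `device` is rejected as an explicit input because it is always supplied.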