Skip to content

Shared Model Manager#

The SharedModelManager class is designed to manage and facilitate the use of machine learning models across different devices, such as CPUs and GPUs, within an asynchronous environment. It ensures safe and efficient execution of these models, particularly in scenarios where GPU resources need to be shared exclusively among multiple models. The manager coordinates access to the shared GPU, preventing conflicts when multiple models require it. Models are only loaded into memory when needed using the fetch_model function.

  • add(): Registers a machine learning model class with the manager. The actual model instance is not loaded at this point.
  • fetch_model(): Retrieves the previously added model class and creates (loads) the actual model instance. This function utilizes PyTorch interface to, to handle device (CPU/GPU) allocation based on availability.

The usage example demonstrates adding models and then using them with their respective functionalities.

⚠️ ❕: We should ALWAYS add model instance on CPU to the pool. This avoids overwhelming the GPU memory, and model pool will automatically put it in GPU when the model is fetched..

model_pool = SharedModelManager()

# Add models instance to the pool
model_pool.add(QRReader())
model_pool.add(Owlv2(model_config=OWLV2Config(device=Device.CPU)))

# Read image
image = Image.open("path/to/your/image.jpg")

# Use QRReader model
async def use_qr_reader():
    # Read image
    image = Image.open("path/to/your/image.jpg")

    qr_reader = await model_pool.fetch_model(QRReader.__name__)
    detections = qr_reader(image)
    # Process detections ...

# Use Owlv2 model
async def use_owlv2():
    # Read image
    image = Image.open("path/to/your/image.jpg")

    owlv2 = await model_pool.fetch_model(Owlv2.__name__)
    prompts = ["a photo of a cat", "a photo of a dog"]
    results = owlv2(image, prompts=prompts)
    # Process results ...

SharedModelManager #

add(model) #

Adds a model to the pool with a device preference.

Parameters:

Name Type Description Default
model Basetool

The modal instance to be added to the pool, it should implement the BaseTool interface.

required
device Device

The preferred device for the model.

required

Returns:

Name Type Description
str str

The model ID to be used for fetching the model.

fetch_model(model_id) #

Retrieves a model from the pool for safe execution.

Parameters:

Name Type Description Default
model_id str

Id to access the model in the pool.

required

Returns:

Name Type Description
Any BaseTool

The retrieved model instance.