Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ realtime = [

workflow_payload_offloading_azure = [
"azure-storage-blob[aio]>=12.28.0,<13.0.0",
"azure-identity[aio]>=1.25.0,<2.0.0",
]
workflow_payload_offloading_gcs = [
"gcloud-aio-storage>=9.3.0,<10.0.0",
Expand Down
1 change: 1 addition & 0 deletions src/mistralai/extra/workflows/encoding/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class BlobStorageConfig(BaseModel):
# Azure settings
container_name: Optional[str] = None
azure_connection_string: Optional[SecretStr] = None
azure_storage_account_url: Optional[str] = None

# GCS settings
bucket_id: Optional[str] = None
Expand Down
29 changes: 25 additions & 4 deletions src/mistralai/extra/workflows/encoding/storage/_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, cast

from azure.core.exceptions import ResourceNotFoundError
from azure.identity.aio import DefaultAzureCredential
from azure.storage.blob.aio import BlobServiceClient
from .blob_storage import BlobNotFoundError, BlobStorage

Expand All @@ -11,14 +12,25 @@ class AzureBlobStorage(BlobStorage):
def __init__(
self,
container_name: str,
azure_connection_string: str,
azure_connection_string: str | None = None,
prefix: str | None = None,
azure_storage_account_url: str | None = None,
):
if azure_connection_string and azure_storage_account_url:
raise ValueError(
"azure_connection_string and azure_storage_account_url are mutually exclusive"
)
if not azure_connection_string and not azure_storage_account_url:
raise ValueError(
"Either azure_connection_string or azure_storage_account_url must be provided"
)
self.container_name = container_name
self.connection_string = azure_connection_string
self.account_url = azure_storage_account_url
self.prefix = prefix or ""
self._service_client: BlobServiceClient | None = None
self._container_client: Any = None
self._credential: Any = None

def _get_full_key(self, key: str) -> str:
if not self.prefix:
Expand All @@ -28,9 +40,16 @@ def _get_full_key(self, key: str) -> str:
return f"{self.prefix}/{key}"

async def __aenter__(self) -> "AzureBlobStorage":
self._service_client = BlobServiceClient.from_connection_string(
self.connection_string
)
if self.connection_string:
self._service_client = BlobServiceClient.from_connection_string(
self.connection_string
)
else:
assert self.account_url is not None
self._credential = DefaultAzureCredential()
self._service_client = BlobServiceClient(
self.account_url, credential=self._credential
)
assert self._service_client is not None
self._container_client = self._service_client.get_container_client(
self.container_name
Expand All @@ -40,6 +59,8 @@ async def __aenter__(self) -> "AzureBlobStorage":
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._service_client:
await self._service_client.close()
if self._credential:
await self._credential.close()

async def upload_blob(self, key: str, content: bytes) -> str:
full_key = self._get_full_key(key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ async def get_blob_storage(
from ._azure import AzureBlobStorage # type: ignore[import-untyped]
except ImportError as e:
raise ImportError(
"Azure Blob Storage support requires azure-storage-blob. "
"Install it with: pip install 'mistralai[workflow_payload_offloading_azure]'"
"Azure Blob Storage support requires azure-storage-blob and azure-identity. "
"Install with: pip install 'mistralai[workflow_payload_offloading_azure]'"
) from e

if not blob_storage_config.container_name:
Expand All @@ -78,14 +78,11 @@ async def get_blob_storage(
if blob_storage_config.azure_connection_string
else None
)
if not azure_conn_str:
raise WorkflowPayloadOffloadingException(
"azure_connection_string is required for Azure blob storage"
)
storage = AzureBlobStorage(
container_name=blob_storage_config.container_name,
azure_connection_string=azure_conn_str,
prefix=prefix,
azure_storage_account_url=blob_storage_config.azure_storage_account_url,
)

elif blob_storage_config.storage_provider == StorageProvider.GCS:
Expand Down
45 changes: 45 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading