13 changes: 5 additions & 8 deletions llama_stack/providers/remote/inference/nvidia/NVIDIA.md
@@ -139,16 +139,13 @@ print(f"Structured Response: {structured_response.choices[0].message.content}")

The following example shows how to create embeddings for an NVIDIA NIM.

-> [!NOTE]
-> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
-
```python
-response = client.inference.embeddings(
-    model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
-    contents=["What is the capital of France?"],
-    task_type="query",
+response = client.embeddings.create(
+    model="nvidia/llama-3.2-nv-embedqa-1b-v2",
+    input=["What is the capital of France?"],
+    extra_body={"input_type": "query"},
)
-print(f"Embeddings: {response.embeddings}")
+print(f"Embeddings: {response.data}")
```
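
For the passage (document) side of an asymmetric model, the same OpenAI-compatible endpoint presumably accepts `input_type="passage"` as the counterpart to `"query"`; the removed note only covered the legacy `embeddings` API's `task_type="document"`. A minimal sketch under that assumption, reusing the `client` from the example above:

```python
# Sketch: passage-side embeddings via the OpenAI-compatible endpoint.
# Assumes the NIM accepts input_type="passage" as the counterpart to
# the "query" value shown in the diff above.
response = client.embeddings.create(
    model="nvidia/llama-3.2-nv-embedqa-1b-v2",
    input=["Paris is the capital and largest city of France."],
    extra_body={"input_type": "passage"},
)
print(f"Embeddings: {response.data}")
```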

### Vision Language Models Example
55 changes: 0 additions & 55 deletions llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -5,14 +5,6 @@
# the root directory of this source tree.


-from openai import NOT_GIVEN
-
-from llama_stack.apis.inference import (
-    OpenAIEmbeddingData,
-    OpenAIEmbeddingsRequestWithExtraBody,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
-)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

@@ -76,50 +68,3 @@ def get_base_url(self) -> str:
        :return: The NVIDIA API base URL
        """
        return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
-
-    async def openai_embeddings(
-        self,
-        params: OpenAIEmbeddingsRequestWithExtraBody,
-    ) -> OpenAIEmbeddingsResponse:
-        """
-        OpenAI-compatible embeddings for NVIDIA NIM.
-
-        Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API.
-        We default this to "query" to ensure requests succeed when using the
-        OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with
-        `task_type='document'`.
-        """
-        extra_body: dict[str, object] = {"input_type": "query"}
-        logger.warning(
-            "NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. "
-            "For passage embeddings, use the embeddings API with task_type='document'."
-        )
-
-        response = await self.client.embeddings.create(
-            model=await self._get_provider_model_id(params.model),
-            input=params.input,
-            encoding_format=params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
-            dimensions=params.dimensions if params.dimensions is not None else NOT_GIVEN,
-            user=params.user if params.user is not None else NOT_GIVEN,
-            extra_body=extra_body,
-        )
-
-        data = []
-        for i, embedding_data in enumerate(response.data):
-            data.append(
-                OpenAIEmbeddingData(
-                    embedding=embedding_data.embedding,
-                    index=i,
-                )
-            )
-
-        usage = OpenAIEmbeddingUsage(
-            prompt_tokens=response.usage.prompt_tokens,
-            total_tokens=response.usage.total_tokens,
-        )
-
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=response.model,
-            usage=usage,
-        )
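
With the override deleted, the adapter falls back to the generic `OpenAIMixin` embeddings path, so nothing injects `input_type` by default and callers supply it via `extra_body`. A rough sketch of the assumed fallback behavior (names and structure are assumptions, not the actual mixin code):

```python
# Assumed shape of the generic mixin path this PR now relies on; the
# real OpenAIMixin implementation may differ.
async def openai_embeddings(self, params):
    # Forward caller-supplied extra fields (e.g. {"input_type": "query"})
    # verbatim instead of hardcoding a default.
    extra_body = getattr(params, "model_extra", None) or {}
    return await self.client.embeddings.create(
        model=await self._get_provider_model_id(params.model),
        input=params.input,
        extra_body=extra_body,
    )
```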
77 changes: 70 additions & 7 deletions tests/integration/inference/test_openai_embeddings.py
@@ -12,6 +12,15 @@

from llama_stack.core.library_client import LlamaStackAsLibraryClient

+ASYMMETRIC_EMBEDDING_MODELS_BY_PROVIDER = {
+    "remote::nvidia": [
+        "nvidia/llama-3.2-nv-embedqa-1b-v2",
+        "nvidia/nv-embedqa-e5-v5",
+        "nvidia/nv-embedqa-mistral-7b-v2",
+        "snowflake/arctic-embed-l",
+    ],
+}
+


def decode_base64_to_floats(base64_string: str) -> list[float]:
    """Helper function to decode base64 string to list of float32 values."""
@@ -29,6 +38,28 @@ def provider_from_model(client_with_models, model_id):
    return providers[provider_id]


+def is_asymmetric_model(client_with_models, model_id):
[Review thread on this line]
Contributor: where is this used now?
Contributor (Author): Updated to put it in `get_extra_body_for_model`: if it is not an asymmetric model, return None for the extra_body.
+    provider = provider_from_model(client_with_models, model_id)
+    provider_type = provider.provider_type
+
+    if provider_type not in ASYMMETRIC_EMBEDDING_MODELS_BY_PROVIDER:
+        return False
+
+    return model_id in ASYMMETRIC_EMBEDDING_MODELS_BY_PROVIDER[provider_type]
+
+
+def get_extra_body_for_model(client_with_models, model_id, input_type="query"):
+    if not is_asymmetric_model(client_with_models, model_id):
+        return None
+
+    provider = provider_from_model(client_with_models, model_id)
+
+    if provider.provider_type == "remote::nvidia":
+        return {"input_type": input_type}
+
+    return None
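
Together these helpers keep the tests provider-agnostic: NVIDIA asymmetric models get an `input_type` extra body, everything else gets `None` (the OpenAI client then sends no extra fields). A quick sketch of the expected behavior (the symmetric model ID is hypothetical):

```python
# Expected helper behavior, following the definitions above
# (assuming the NVIDIA provider is registered for this model):
assert get_extra_body_for_model(client_with_models, "nvidia/llama-3.2-nv-embedqa-1b-v2") == {"input_type": "query"}
assert get_extra_body_for_model(
    client_with_models, "nvidia/llama-3.2-nv-embedqa-1b-v2", input_type="passage"
) == {"input_type": "passage"}
assert get_extra_body_for_model(client_with_models, "all-MiniLM-L6-v2") is None  # hypothetical symmetric model ID
```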


def skip_if_model_doesnt_support_user_param(client, model_id):
    provider = provider_from_model(client, model_id)
    if provider.provider_type in (
@@ -40,17 +71,29 @@ def skip_if_model_doesnt_support_user_param(client, model_id):

def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
    provider = provider_from_model(client, model_id)
-    if provider.provider_type in (
+
+    should_skip = provider.provider_type in (
        "remote::databricks",  # param silently ignored, always returns floats
        "remote::fireworks",  # param silently ignored, always returns list of floats
        "remote::ollama",  # param silently ignored, always returns list of floats
-    ):
+    ) or (
+        provider.provider_type == "remote::nvidia"
+        and model_id
+        in [
+            "nvidia/nv-embedqa-e5-v5",
+            "nvidia/nv-embedqa-mistral-7b-v2",
+            "snowflake/arctic-embed-l",
+        ]
+    )
+
+    if should_skip:
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")


def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_id):
    provider = provider_from_model(client_with_models, model_id)
-    if (
+
+    should_skip = (
        provider.provider_type
        in (
            "remote::together",  # returns 400
@@ -59,11 +102,19 @@ def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_i
"remote::databricks",
"remote::watsonx", # openai.BadRequestError: Error code: 400 - {'detail': "litellm.UnsupportedParamsError: watsonx does not support parameters: {'dimensions': 384}
)
):
pytest.skip(
f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."
or (provider.provider_type == "remote::openai" and "text-embedding-3" not in model_id)
or (
provider.provider_type == "remote::nvidia"
and model_id
in [
"nvidia/nv-embedqa-e5-v5",
"nvidia/nv-embedqa-mistral-7b-v2",
"snowflake/arctic-embed-l",
]
)
if provider.provider_type == "remote::openai" and "text-embedding-3" not in model_id:
)

if should_skip:
pytest.skip(
f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."
)
@@ -105,6 +156,7 @@ def test_openai_embeddings_single_string(compat_client, client_with_models, embe
        model=embedding_model_id,
        input=input_text,
        encoding_format="float",
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
    )

    assert response.object == "list"
@@ -129,6 +181,7 @@ def test_openai_embeddings_multiple_strings(compat_client, client_with_models, e
        model=embedding_model_id,
        input=input_texts,
        encoding_format="float",
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
    )

    assert response.object == "list"
@@ -155,6 +208,7 @@ def test_openai_embeddings_with_encoding_format_float(compat_client, client_with
        model=embedding_model_id,
        input=input_text,
        encoding_format="float",
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
    )

    assert response.object == "list"
@@ -175,6 +229,7 @@ def test_openai_embeddings_with_dimensions(compat_client, client_with_models, em
        model=embedding_model_id,
        input=input_text,
        dimensions=dimensions,
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
    )

    assert response.object == "list"
@@ -196,6 +251,7 @@ def test_openai_embeddings_with_user_parameter(compat_client, client_with_models
        model=embedding_model_id,
        input=input_text,
        user=user_id,
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
    )

    assert response.object == "list"
@@ -212,6 +268,7 @@ def test_openai_embeddings_empty_list_error(compat_client, client_with_models, e
        compat_client.embeddings.create(
            model=embedding_model_id,
            input=[],
+            extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
        )


@@ -223,6 +280,7 @@ def test_openai_embeddings_invalid_model_error(compat_client, client_with_models
        compat_client.embeddings.create(
            model="invalid-model-id",
            input="Test text",
+            extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
        )


@@ -233,16 +291,19 @@ def test_openai_embeddings_different_inputs_different_outputs(compat_client, cli
input_text1 = "This is the first text"
input_text2 = "This is completely different content"

extra_body = get_extra_body_for_model(client_with_models, embedding_model_id)
response1 = compat_client.embeddings.create(
model=embedding_model_id,
input=input_text1,
encoding_format="float",
extra_body=extra_body,
)

response2 = compat_client.embeddings.create(
model=embedding_model_id,
input=input_text2,
encoding_format="float",
extra_body=extra_body,
)

embedding1 = response1.data[0].embedding
@@ -267,6 +328,7 @@ def test_openai_embeddings_with_encoding_format_base64(compat_client, client_wit
        input=input_text,
        encoding_format="base64",
        dimensions=dimensions,
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
    )

    # Validate response structure
@@ -298,6 +360,7 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
        model=embedding_model_id,
        input=input_texts,
        encoding_format="base64",
+        extra_body=get_extra_body_for_model(client_with_models, embedding_model_id),
    )
    # Validate response structure
    assert response.object == "list"