77 changes: 58 additions & 19 deletions clarifai/client/model_client.py
@@ -11,6 +11,7 @@
from clarifai.errors import UserError
from clarifai.runners.utils import code_script, method_signatures
from clarifai.runners.utils.method_signatures import (
RESERVED_PARAM_WITH_PROTO,
CompatibilitySerializer,
deserialize,
get_stream_from_signature,
@@ -204,6 +205,9 @@ def _define_functions(self):

def bind_f(method_name, method_argnames, call_func, async_call_func):
def sync_f(*args, **kwargs):
# Extract with_proto parameter if present
with_proto = kwargs.pop(RESERVED_PARAM_WITH_PROTO, False)

if len(args) > len(method_argnames):
raise TypeError(
f"{method_name}() takes {len(method_argnames)} positional arguments but {len(args)} were given"
@@ -221,18 +225,21 @@ def sync_f(*args, **kwargs):
)
if is_batch_input_valid and (not is_openai_chat_format(batch_inputs)):
# If the batch input is valid, call the function with the batch inputs and the method name
return call_func(batch_inputs, method_name, with_proto=with_proto)

for name, arg in zip(
method_argnames, args
): # handle positional with zip shortest
if name in kwargs:
raise TypeError(f"Multiple values for argument {name}")
kwargs[name] = arg
return call_func(kwargs, method_name, with_proto=with_proto)

async def async_f(*args, **kwargs):
# Async counterpart of sync_f; dispatches through async_call_func
# Extract with_proto parameter if present
with_proto = kwargs.pop(RESERVED_PARAM_WITH_PROTO, False)

if len(args) > len(method_argnames):
raise TypeError(
f"{method_name}() takes {len(method_argnames)} positional arguments but {len(args)} were given"
@@ -249,7 +256,9 @@ async def async_f(*args, **kwargs):
)
if is_batch_input_valid and (not is_openai_chat_format(batch_inputs)):
# If the batch input is valid, call the function with the batch inputs and the method name
return async_call_func(
batch_inputs, method_name, with_proto=with_proto
)

for name, arg in zip(
method_argnames, args
@@ -258,7 +267,7 @@ async def async_f(*args, **kwargs):
raise TypeError(f"Multiple values for argument {name}")
kwargs[name] = arg

return async_call_func(kwargs, method_name, with_proto=with_proto)

class MethodWrapper:
def __call__(self, *args, **kwargs):
@@ -364,6 +373,7 @@ def _predict(
self,
inputs, # TODO set up functions according to fetched signatures?
method_name: str = 'predict',
with_proto: bool = False,
) -> Any:
input_signature = self._method_signatures[method_name].input_fields
output_signature = self._method_signatures[method_name].output_fields
@@ -385,9 +395,12 @@
outputs = []
for output in response.outputs:
outputs.append(deserialize(output.data, output_signature, is_output=True))
result = outputs if batch_input else outputs[0]

if with_proto:
return result, response
return result

def _predict_by_proto(
self,
@@ -448,15 +461,17 @@ async def _async_predict(
self,
inputs,
method_name: str = 'predict',
with_proto: bool = False,
) -> Any:
"""Asynchronously process inputs and make predictions.

Args:
inputs: Input data to process
method_name (str): Name of the method to call
with_proto (bool): If True, return both the processed result and the raw protobuf response

Returns:
Processed prediction results, optionally with protobuf response
"""
# method_name defaults to 'predict' so the async path mirrors the input/output signature behaviour of sync predict.
input_signature = self._method_signatures[method_name].input_fields
@@ -477,7 +492,11 @@ async def _async_predict(
for output in response.outputs:
outputs.append(deserialize(output.data, output_signature, is_output=True))

result = outputs if batch_input else outputs[0]

if with_proto:
return result, response
return result

async def _async_predict_by_proto(
self,
@@ -551,6 +570,7 @@ def _generate(
self,
inputs, # TODO set up functions according to fetched signatures?
method_name: str = 'generate',
with_proto: bool = False,
) -> Any:
input_signature = self._method_signatures[method_name].input_fields
output_signature = self._method_signatures[method_name].output_fields
@@ -572,10 +592,13 @@
outputs = []
for output in response.outputs:
outputs.append(deserialize(output.data, output_signature, is_output=True))
result = outputs if batch_input else outputs[0]

if with_proto:
yield result, response
else:
yield result

def _generate_by_proto(
self,
@@ -641,6 +664,7 @@ async def _async_generate(
self,
inputs,
method_name: str = 'generate',
with_proto: bool = False,
) -> Any:
# method_name defaults to 'generate' so the async path mirrors the input/output signature behaviour of sync generate.
input_signature = self._method_signatures[method_name].input_fields
@@ -654,18 +678,21 @@ async def _async_generate(
proto_inputs = []
for input in inputs:
proto = resources_pb2.Input()
serialize(input, input_signature, proto.data)
proto_inputs.append(proto)
response_stream = self._async_generate_by_proto(proto_inputs, method_name)

async for response in response_stream:
outputs = []
for output in response.outputs:
outputs.append(deserialize(output.data, output_signature, is_output=True))
result = outputs if batch_input else outputs[0]

if with_proto:
yield result, response
else:
yield result

async def _async_generate_by_proto(
self,
@@ -734,6 +761,7 @@ def _stream(
self,
inputs,
method_name: str = 'stream',
with_proto: bool = False,
) -> Any:
input_signature = self._method_signatures[method_name].input_fields
output_signature = self._method_signatures[method_name].output_fields
@@ -775,7 +803,12 @@ def _input_proto_stream():

for response in response_stream:
assert len(response.outputs) == 1, 'streaming methods must have exactly one output'
result = deserialize(response.outputs[0].data, output_signature, is_output=True)

if with_proto:
yield result, response
else:
yield result

def _req_iterator(
self,
@@ -843,6 +876,7 @@ async def _async_stream(
self,
inputs,
method_name: str = 'stream',
with_proto: bool = False,
) -> Any:
# method_name defaults to 'stream' so the async path mirrors the input/output signature behaviour of sync stream.
input_signature = self._method_signatures[method_name].input_fields
@@ -885,7 +919,12 @@ async def _input_proto_stream():

async for response in response_stream:
assert len(response.outputs) == 1, 'streaming methods must have exactly one output'
result = deserialize(response.outputs[0].data, output_signature, is_output=True)

if with_proto:
yield result, response
else:
yield result

async def _async_stream_by_proto(
self,
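
Taken together, these changes let callers opt in to receiving the raw protobuf alongside the deserialized result. A minimal sketch of the resulting calling convention; the model URL and the prompt parameter are placeholders, not part of this diff:

# Sketch only: assumes a pythonic model that exposes a `prompt` argument.
from clarifai.client.model import Model

model = Model(url="https://clarifai.com/user/app/models/my-pythonic-model", pat="...")

# Default path: behaviour unchanged; only the deserialized result comes back.
text = model.predict(prompt="What is AI?")

# Opt-in path: bind_f pops the reserved kwarg and forwards it to _predict,
# which returns (result, raw_protobuf_response) when with_proto=True.
text, proto = model.predict(prompt="What is AI?", with_proto=True)
print(proto.status.code)

# Generator methods instead yield (result, response) tuples per chunk.
for chunk, proto in model.generate(prompt="Explain quantum computing", with_proto=True):
    print(chunk, proto.status.req_id)
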
9 changes: 9 additions & 0 deletions clarifai/runners/utils/method_signatures.py
@@ -23,6 +23,9 @@
TupleSerializer,
)

# Reserved parameter name for protobuf response access
RESERVED_PARAM_WITH_PROTO = 'with_proto'


def build_function_signature(func):
'''
@@ -45,6 +48,12 @@ def build_function_signature(func):
input_sigs = []
input_streaming = []
for p in sig.parameters.values():
# Validate that user methods don't use reserved parameter names
if p.name == RESERVED_PARAM_WITH_PROTO:
raise ValueError(
f"Parameter name '{RESERVED_PARAM_WITH_PROTO}' is reserved and cannot be used in model methods. "
f"This parameter is automatically added by the framework to provide access to protobuf responses."
)
model_type_field, _, streaming = build_variable_signature(p.name, p.annotation, p.default)
input_sigs.append(model_type_field)
input_streaming.append(streaming)
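
The new guard surfaces immediately when a model author declares the reserved name. A quick sketch, assuming build_function_signature is applied directly to a plain function; `caption` is a hypothetical method, not part of this diff:

from clarifai.runners.utils.method_signatures import build_function_signature

def caption(prompt: str, with_proto: bool = False) -> str:  # illegal: reserved name
    return prompt

try:
    build_function_signature(caption)
except ValueError as e:
    print(e)  # Parameter name 'with_proto' is reserved and cannot be used ...
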
139 changes: 139 additions & 0 deletions examples/with_proto_demo.py
@@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""
Example demonstrating the with_proto functionality in Clarifai Python SDK

This feature allows pythonic models to return both the processed result
and the raw protobuf response from the API.
"""

import os

from clarifai.client.model import Model


def demo_with_proto_functionality():
"""
Demonstrate how to use with_proto parameter with pythonic models
"""
print("=== Clarifai with_proto Feature Demo ===\n")

# Example model URL and configuration
# Note: Replace with actual model URL and credentials for real usage
model_url = "https://clarifai.com/user/app/models/my-pythonic-model"
deployment_id = "my-deployment-id"
pat = os.getenv('CLARIFAI_PAT')
if not pat:
pat = 'your-pat-token' # Placeholder for demo purposes

try:
# Initialize the model client
model = Model(
url=model_url,
pat=pat,
deployment_id=deployment_id,
)

print("Model initialized successfully!\n")

# Example 1: Basic predict without with_proto (existing behavior)
print("1. Standard predict call (existing behavior):")
print(" response = model.predict(prompt='What is AI?')")
print(" # Returns only the processed result\n")

# Example 2: Predict with with_proto=True (NEW feature)
print("2. Predict with protobuf response (NEW feature):")
print(" response, proto = model.predict(")
print(" prompt='What is AI?',")
print(" with_proto=True # <- This is the new parameter")
print(" )")
print(" # Returns tuple: (processed_result, raw_protobuf_response)\n")

# Example 3: Generate streaming with with_proto
print("3. Streaming generate with protobuf:")
print(" for response, proto in model.generate(")
print(" prompt='Explain quantum computing',")
print(" with_proto=True")
print(" ):")
print(" print(f'Generated: {response}')")
print(" print(f'Status: {proto.status.code}')")
print(" # Access raw metadata like timestamps, request IDs, etc.\n")

# Example 4: Custom model method with with_proto
print("4. Custom model methods support with_proto:")
print(" # Any pythonic model method automatically supports with_proto")
print(" result, proto = model.my_custom_method(")
print(" input_data='some data',")
print(" temperature=0.7,")
print(" with_proto=True")
print(" )")
print(" # Works with any method defined in the ModelClass\n")

# Benefits section
print("Benefits of with_proto=True:")
print("- Access to complete API response metadata")
print("- Debugging capabilities (status codes, request IDs)")
print("- Performance metrics (latency, processing info)")
print("- Raw data for advanced use cases")
print("- Backward compatible (existing code unchanged)")

except Exception as e:
print(f"Demo setup note: {e}")
print("This is a demonstration script showing the API usage.")
print("For actual usage, ensure you have valid model URLs and credentials.\n")

# Show the implementation regardless
show_implementation_details()


def show_implementation_details():
"""Show technical implementation details"""
print("\n=== Implementation Details ===\n")

print("The with_proto feature is implemented by modifying the ModelClient class:")
print("1. All method binding now extracts 'with_proto' parameter")
print("2. Core methods (_predict, _generate, _stream) accept with_proto")
print("3. When with_proto=True, methods return (result, proto_response)")
print("4. When with_proto=False (default), methods return just result")
print("5. Fully backward compatible - no existing code needs changes\n")

print("Supported methods:")
print("- Synchronous: predict(), generate(), stream(), custom_method()")
print("- Asynchronous: async_predict(), async_generate(), async_stream()")
print("- All methods support with_proto parameter consistently\n")


def technical_example():
"""Show a technical example of what the protobuf response contains"""
print("=== What's in the Protobuf Response? ===\n")

print("The protobuf response contains rich metadata:")
print("""
proto.status.code # Success/error status
proto.status.description # Human readable status
proto.status.details # Additional error details
proto.status.req_id # Request ID for debugging
proto.status.percent_completed # Progress for long operations

proto.outputs[0].data # Raw output data
proto.outputs[0].status # Per-output status
proto.model.id # Model information
proto.model.model_version.id # Model version used

# And much more depending on the model and operation
""")

print("Example usage for debugging:")
print("""
try:
result, proto = model.predict(text="Hello", with_proto=True)
print(f"Success! Request ID: {proto.status.req_id}")
except Exception as e:
print(f"Error: {proto.status.description}")
print(f"Debug with Request ID: {proto.status.req_id}")
""")


if __name__ == '__main__':
demo_with_proto_functionality()
technical_example()
print("\n✅ Demo complete! The with_proto feature is ready to use.")