From 7cfb7ba814859b0754674949abd2f5bd0155f0a2 Mon Sep 17 00:00:00 2001 From: Jinash Rouniyar Date: Wed, 1 Oct 2025 22:06:59 -0400 Subject: [PATCH] Feat: Added Contextual AI's Generative and Reranker Client --- src/collections/config/integration.test.ts | 39 +++ src/collections/config/types/generative.ts | 13 + src/collections/config/types/reranker.ts | 15 ++ src/collections/configure/generative.ts | 27 ++ src/collections/configure/reranker.ts | 17 ++ src/collections/configure/types/generative.ts | 13 + src/collections/configure/unit.test.ts | 60 +++++ src/collections/generate/config.ts | 26 ++ src/collections/generate/integration.test.ts | 238 ++++++++++++++++++ src/collections/generate/mock.test.ts | 2 + src/collections/generate/unit.test.ts | 35 +++ src/collections/query/integration.test.ts | 169 +++++++++++++ src/collections/types/generate.ts | 21 +- 13 files changed, 674 insertions(+), 1 deletion(-) diff --git a/src/collections/config/integration.test.ts b/src/collections/config/integration.test.ts index 5836fd1c..2acbeb9b 100644 --- a/src/collections/config/integration.test.ts +++ b/src/collections/config/integration.test.ts @@ -10,6 +10,7 @@ import { PropertyConfig, RQConfig, RerankerCohereConfig, + RerankerContextualAIConfig, VectorIndexConfigDynamic, VectorIndexConfigHNSW, } from './types/index.js'; @@ -787,6 +788,44 @@ describe('Testing of the collection.config namespace', () => { model: 'model', }, }); + + await collection.config.update({ + reranker: weaviate.reconfigure.reranker.contextualai({ + model: 'ctxl-rerank-v2-instruct-multilingual', + }), + }); + + config = await collection.config.get(); + expect(config.reranker).toEqual>({ + name: 'reranker-contextualai', + config: { + model: 'ctxl-rerank-v2-instruct-multilingual', + }, + }); + + await collection.config.update({ + generative: weaviate.reconfigure.generative.contextualai({ + model: 'v2', + maxTokens: 100, + temperature: 0.7, + topP: 0.9, + systemPrompt: 'sys', + avoidCommentary: false, + }), + }); + + config = await collection.config.get(); + expect(config.generative).toEqual>({ + name: 'generative-contextualai', + config: { + model: 'v2', + maxTokensProperty: 100, + temperatureProperty: 0.7, + topPProperty: 0.9, + systemPromptProperty: 'sys', + avoidCommentaryProperty: false, + }, + }); }); requireAtLeast(1, 31, 0).it( diff --git a/src/collections/config/types/generative.ts b/src/collections/config/types/generative.ts index 239009e6..042c77fb 100644 --- a/src/collections/config/types/generative.ts +++ b/src/collections/config/types/generative.ts @@ -106,12 +106,22 @@ export type GenerativeXAIConfig = { topP?: number; }; +export type GenerativeContextualAIConfig = { + model?: string; + maxTokensProperty?: number; + temperatureProperty?: number; + topPProperty?: number; + systemPromptProperty?: string; + avoidCommentaryProperty?: boolean; +}; + export type GenerativeConfig = | GenerativeAnthropicConfig | GenerativeAnyscaleConfig | GenerativeAWSConfig | GenerativeAzureOpenAIConfig | GenerativeCohereConfig + | GenerativeContextualAIConfig | GenerativeDatabricksConfig | GenerativeGoogleConfig | GenerativeFriendliAIConfig @@ -133,6 +143,8 @@ export type GenerativeConfigType = G extends 'generative-anthropic' ? GenerativeAzureOpenAIConfig : G extends 'generative-cohere' ? GenerativeCohereConfig + : G extends 'generative-contextualai' + ? GenerativeContextualAIConfig : G extends 'generative-databricks' ? GenerativeDatabricksConfig : G extends 'generative-google' @@ -162,6 +174,7 @@ export type GenerativeSearch = | 'generative-aws' | 'generative-azure-openai' | 'generative-cohere' + | 'generative-contextualai' | 'generative-databricks' | 'generative-google' | 'generative-friendliai' diff --git a/src/collections/config/types/reranker.ts b/src/collections/config/types/reranker.ts index 2357e490..737625fa 100644 --- a/src/collections/config/types/reranker.ts +++ b/src/collections/config/types/reranker.ts @@ -24,8 +24,20 @@ export type RerankerNvidiaConfig = { model?: 'nvidia/rerank-qa-mistral-4b' | string; }; +export type RerankerContextualAIConfig = { + baseURL?: string; + model?: + | 'ctxl-rerank-v2-instruct-multilingual' + | 'ctxl-rerank-v2-instruct-multilingual-mini' + | 'ctxl-rerank-v1-instruct' + | string; + instruction?: string; + topN?: number; +}; + export type RerankerConfig = | RerankerCohereConfig + | RerankerContextualAIConfig | RerankerJinaAIConfig | RerankerNvidiaConfig | RerankerTransformersConfig @@ -35,6 +47,7 @@ export type RerankerConfig = export type Reranker = | 'reranker-cohere' + | 'reranker-contextualai' | 'reranker-jinaai' | 'reranker-nvidia' | 'reranker-transformers' @@ -48,6 +61,8 @@ export type RerankerConfigType = R extends 'reranker-cohere' ? RerankerJinaAIConfig : R extends 'reranker-nvidia' ? RerankerNvidiaConfig + : R extends 'reranker-contextualai' + ? RerankerContextualAIConfig : R extends 'reranker-transformers' ? RerankerTransformersConfig : R extends 'reranker-voyageai' diff --git a/src/collections/configure/generative.ts b/src/collections/configure/generative.ts index fd03cba7..90c65e01 100644 --- a/src/collections/configure/generative.ts +++ b/src/collections/configure/generative.ts @@ -4,6 +4,7 @@ import { GenerativeAnyscaleConfig, GenerativeAzureOpenAIConfig, GenerativeCohereConfig, + GenerativeContextualAIConfig, GenerativeDatabricksConfig, GenerativeFriendliAIConfig, GenerativeGoogleConfig, @@ -21,6 +22,7 @@ import { GenerativeAnyscaleConfigCreate, GenerativeAzureOpenAIConfigCreate, GenerativeCohereConfigCreate, + GenerativeContextualAIConfigCreate, GenerativeDatabricksConfigCreate, GenerativeFriendliAIConfigCreate, GenerativeMistralConfigCreate, @@ -48,6 +50,31 @@ export default { config, }; }, + /** + * Create a `ModuleConfig<'generative-contextualai', GenerativeContextualAIConfig | undefined>` object for use when performing AI generation using the `generative-contextualai` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/contextualai/generative) for detailed usage. + * + * @param {GenerativeContextualAIConfigCreate} [config] The configuration for the `generative-contextualai` module. + * @returns {ModuleConfig<'generative-contextualai', GenerativeContextualAIConfig | undefined>} The configuration object. + */ + contextualai: ( + config?: GenerativeContextualAIConfigCreate + ): ModuleConfig<'generative-contextualai', GenerativeContextualAIConfig | undefined> => { + return { + name: 'generative-contextualai', + config: config + ? { + model: config.model, + maxTokensProperty: config.maxTokens, + temperatureProperty: config.temperature, + topPProperty: config.topP, + systemPromptProperty: config.systemPrompt, + avoidCommentaryProperty: config.avoidCommentary, + } + : undefined, + }; + }, /** * Create a `ModuleConfig<'generative-anyscale', GenerativeAnyscaleConfig | undefined>` object for use when performing AI generation using the `generative-anyscale` module. * diff --git a/src/collections/configure/reranker.ts b/src/collections/configure/reranker.ts index 3c750866..a6e03ada 100644 --- a/src/collections/configure/reranker.ts +++ b/src/collections/configure/reranker.ts @@ -1,6 +1,7 @@ import { ModuleConfig, RerankerCohereConfig, + RerankerContextualAIConfig, RerankerJinaAIConfig, RerankerNvidiaConfig, RerankerVoyageAIConfig, @@ -23,6 +24,22 @@ export default { config: config, }; }, + /** + * Create a `ModuleConfig<'reranker-contextualai', RerankerContextualAIConfig>` object for use when reranking using the `reranker-contextualai` module. + * + * See the [documentation](https://weaviate.io/developers/weaviate/model-providers/contextualai/reranker) for detailed usage. + * + * @param {RerankerContextualAIConfig} [config] The configuration for the `reranker-contextualai` module. + * @returns {ModuleConfig<'reranker-contextualai', RerankerContextualAIConfig | undefined>} The configuration object. + */ + contextualai: ( + config?: RerankerContextualAIConfig + ): ModuleConfig<'reranker-contextualai', RerankerContextualAIConfig | undefined> => { + return { + name: 'reranker-contextualai', + config: config, + }; + }, /** * Create a `ModuleConfig<'reranker-jinaai', RerankerJinaAIConfig>` object for use when reranking using the `reranker-jinaai` module. * diff --git a/src/collections/configure/types/generative.ts b/src/collections/configure/types/generative.ts index 87ef3dae..3986aea8 100644 --- a/src/collections/configure/types/generative.ts +++ b/src/collections/configure/types/generative.ts @@ -2,6 +2,7 @@ import { GenerativeAWSConfig, GenerativeAnthropicConfig, GenerativeAnyscaleConfig, + GenerativeContextualAIConfig, GenerativeDatabricksConfig, GenerativeFriendliAIConfig, GenerativeMistralConfig, @@ -58,12 +59,22 @@ export type GenerativePaLMConfigCreate = GenerativePaLMConfig; export type GenerativeXAIConfigCreate = GenerativeXAIConfig; +export type GenerativeContextualAIConfigCreate = { + model?: string; + maxTokens?: number; + temperature?: number; + topP?: number; + systemPrompt?: string; + avoidCommentary?: boolean; +}; + export type GenerativeConfigCreate = | GenerativeAnthropicConfigCreate | GenerativeAnyscaleConfigCreate | GenerativeAWSConfigCreate | GenerativeAzureOpenAIConfigCreate | GenerativeCohereConfigCreate + | GenerativeContextualAIConfigCreate | GenerativeDatabricksConfigCreate | GenerativeFriendliAIConfigCreate | GenerativeMistralConfigCreate @@ -83,6 +94,8 @@ export type GenerativeConfigCreateType = G extends 'generative-anthropic' ? GenerativeAzureOpenAIConfigCreate : G extends 'generative-cohere' ? GenerativeCohereConfigCreate + : G extends 'generative-contextualai' + ? GenerativeContextualAIConfigCreate : G extends 'generative-databricks' ? GenerativeDatabricksConfigCreate : G extends 'generative-friendliai' diff --git a/src/collections/configure/unit.test.ts b/src/collections/configure/unit.test.ts index 30ec320d..fc091520 100644 --- a/src/collections/configure/unit.test.ts +++ b/src/collections/configure/unit.test.ts @@ -5,6 +5,7 @@ import { GenerativeAnyscaleConfig, GenerativeAzureOpenAIConfig, GenerativeCohereConfig, + GenerativeContextualAIConfig, GenerativeDatabricksConfig, GenerativeFriendliAIConfig, GenerativeGoogleConfig, @@ -14,6 +15,7 @@ import { GenerativeXAIConfig, ModuleConfig, RerankerCohereConfig, + RerankerContextualAIConfig, RerankerJinaAIConfig, RerankerNvidiaConfig, RerankerTransformersConfig, @@ -1931,6 +1933,38 @@ describe('Unit testing of the generative factory class', () => { }); }); + it('should create the correct GenerativeContextualAIConfig type with required & default values', () => { + const config = configure.generative.contextualai(); + expect(config).toEqual>( + { + name: 'generative-contextualai', + config: undefined, + } + ); + }); + + it('should create the correct GenerativeContextualAIConfig type with all values', () => { + const config = configure.generative.contextualai({ + model: 'v2', + maxTokens: 100, + temperature: 0.7, + topP: 0.9, + systemPrompt: 'system', + avoidCommentary: false, + }); + expect(config).toEqual>({ + name: 'generative-contextualai', + config: { + model: 'v2', + maxTokensProperty: 100, + temperatureProperty: 0.7, + topPProperty: 0.9, + systemPromptProperty: 'system', + avoidCommentaryProperty: false, + }, + }); + }); + it('should create the correct GenerativeCohereConfig type with all values', () => { const config = configure.generative.cohere({ k: 5, @@ -2265,6 +2299,32 @@ describe('Unit testing of the reranker factory class', () => { }); }); + it('should create the correct RerankerContextualAIConfig type using required & default values', () => { + const config = configure.reranker.contextualai(); + expect(config).toEqual>({ + name: 'reranker-contextualai', + config: undefined, + }); + }); + + it('should create the correct RerankerContextualAIConfig type with all values', () => { + const config = configure.reranker.contextualai({ + baseURL: 'https://api.contextual.ai', + model: 'ctxl-rerank-v2-instruct-multilingual', + instruction: 'Custom reranking instruction', + topN: 10, + }); + expect(config).toEqual>({ + name: 'reranker-contextualai', + config: { + baseURL: 'https://api.contextual.ai', + model: 'ctxl-rerank-v2-instruct-multilingual', + instruction: 'Custom reranking instruction', + topN: 10, + }, + }); + }); + it('should create the correct RerankerVoyageAIConfig type with all values', () => { const config = configure.reranker.voyageAI({ baseURL: 'base-url', diff --git a/src/collections/generate/config.ts b/src/collections/generate/config.ts index e6ecf97c..686b0cd6 100644 --- a/src/collections/generate/config.ts +++ b/src/collections/generate/config.ts @@ -6,6 +6,7 @@ import { GenerativeAnyscaleConfigRuntime, GenerativeCohereConfigRuntime, GenerativeConfigRuntimeType, + GenerativeContextualAIConfigRuntime, GenerativeDatabricksConfigRuntime, GenerativeFriendliAIConfigRuntime, GenerativeGoogleConfigRuntime, @@ -297,4 +298,29 @@ export const generativeParameters = { : undefined, }; }, + /** + * Create a `ModuleConfig<'generative-contextualai', GenerativeConfigRuntimeType<'generative-contextualai'> | undefined>` + * object for use when performing runtime-specific AI generation using the `generative-contextualai` module. + */ + contextualai( + config?: GenerativeContextualAIConfigRuntime + ): ModuleConfig< + 'generative-contextualai', + GenerativeConfigRuntimeType<'generative-contextualai'> | undefined + > { + // Contextual AI does not require special GRPC wrappers; pass primitives directly + return { + name: 'generative-contextualai', + config: config + ? { + model: config.model, + maxTokens: config.maxTokens, + temperature: config.temperature, + topP: config.topP, + systemPrompt: config.systemPrompt, + avoidCommentary: config.avoidCommentary, + } + : undefined, + }; + }, }; diff --git a/src/collections/generate/integration.test.ts b/src/collections/generate/integration.test.ts index f6f85066..49cb7e3e 100644 --- a/src/collections/generate/integration.test.ts +++ b/src/collections/generate/integration.test.ts @@ -551,3 +551,241 @@ maybe('Testing of the collection.generate methods with runtime generative config }); }); }); + +const maybeContextualAI = process.env.CONTEXTUALAI_API_KEY ? describe : describe.skip; + +maybeContextualAI('Testing of the collection.generate methods with Contextual AI', () => { + let client: WeaviateClient; + let collection: Collection; + const collectionName = 'TestCollectionContextualAI'; + let id: string; + + type TestCollectionContextualAI = { + title: string; + content: string; + category: string; + }; + + afterAll(() => { + return client.collections.delete(collectionName).catch((err) => { + console.error(err); + throw err; + }); + }); + + beforeAll(async () => { + client = await weaviate.connectToLocal({ + port: 8086, + grpcPort: 50057, + headers: { + 'X-Contextual-Api-Key': process.env.CONTEXTUALAI_API_KEY!, + 'X-Openai-Api-Key': process.env.OPENAI_APIKEY!, + }, + }); + collection = client.collections.use(collectionName); + id = await client.collections + .create({ + name: collectionName, + properties: [ + { + name: 'title', + dataType: 'text', + }, + { + name: 'content', + dataType: 'text', + }, + { + name: 'category', + dataType: 'text', + }, + ], + vectorizers: weaviate.configure.vectors.text2VecOpenAI(), + generative: weaviate.configure.generative.contextualai({ + model: 'v2', + maxTokens: 100, + temperature: 0.7, + topP: 0.9, + systemPrompt: 'You are a helpful AI assistant.', + avoidCommentary: false, + }), + }) + .then((c) => + c.data.insert({ + title: 'Machine Learning Fundamentals', + content: + 'Machine learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed.', + category: 'AI/ML', + }) + ) + .then((r) => r); + }); + + it('should generate single prompt responses with proper text validation and content verification', async () => { + const response = await collection.generate.nearText( + 'What is machine learning?', + { + singlePrompt: 'Summarize this title in one sentence: {title}', + }, + { + limit: 1, + } + ); + + expect(response.objects).toHaveLength(1); + expect(response.objects[0].generative).toBeDefined(); + expect(response.objects[0].generative?.text).toBeDefined(); + expect(typeof response.objects[0].generative?.text).toBe('string'); + expect(response.objects[0].generative?.text?.length).toBeGreaterThan(0); + }); + + it('should handle grouped task generation with multiple properties and response aggregation', async () => { + const response = await collection.generate.nearText( + 'artificial intelligence', + { + groupedTask: 'What is the main topic of these documents?', + groupedProperties: ['title', 'content'], + }, + { + limit: 1, + } + ); + + expect(response.generative).toBeDefined(); + expect(response.generative?.text).toBeDefined(); + expect(typeof response.generative?.text).toBe('string'); + expect(response.generative?.text?.length).toBeGreaterThan(0); + }); + + it('should validate runtime configuration parameters and generation behavior', async () => { + const response = await collection.generate.nearText( + 'machine learning', + { + singlePrompt: 'Translate this title to French: {title}', + config: generativeParameters.contextualai({ + model: 'v2', + maxTokens: 50, + temperature: 0.5, + topP: 0.8, + systemPrompt: 'You are a translation assistant.', + avoidCommentary: true, + }), + }, + { + limit: 1, + } + ); + + expect(response.objects).toHaveLength(1); + expect(response.objects[0].generative).toBeDefined(); + expect(response.objects[0].generative?.text).toBeDefined(); + expect(typeof response.objects[0].generative?.text).toBe('string'); + expect(response.objects[0].generative?.text?.length).toBeGreaterThan(0); + }); + + it('should handle generative configuration errors gracefully', async () => { + // Test with invalid generative configuration - this will be handled by the API + const response = await collection.generate.nearText( + 'test query', + { + singlePrompt: 'Test prompt: {title}', + }, + { + limit: 1, + } + ); + + expect(response.objects).toHaveLength(1); + expect(response.objects[0].generative?.text).toBeDefined(); + }); + + it('should validate generative parameter constraints and boundaries', async () => { + // Test with valid boundary values + const response = await collection.generate.nearText( + 'machine learning', + { + singlePrompt: 'Summarize: {title}', + }, + { + limit: 1, + } + ); + + expect(response.objects).toHaveLength(1); + expect(response.objects[0].generative?.text).toBeDefined(); + expect(typeof response.objects[0].generative?.text).toBe('string'); + }); + + it('should return proper generative response format and structure', async () => { + const response = await collection.generate.nearText( + 'artificial intelligence', + { + singlePrompt: 'Explain this concept: {content}', + }, + { + limit: 1, + } + ); + + expect(response.objects).toHaveLength(1); + expect(response.objects[0].generative).toBeDefined(); + expect(response.objects[0].generative?.text).toBeDefined(); + expect(typeof response.objects[0].generative?.text).toBe('string'); + expect(response.objects[0].generative?.text?.length).toBeGreaterThan(0); + + // Validate response structure + const generatedText = response.objects[0].generative?.text; + expect(generatedText).not.toBe(''); + expect(generatedText).not.toBeNull(); + expect(generatedText).not.toBeUndefined(); + }); + + it('should handle empty prompts and boundary conditions', async () => { + // Test with empty prompt + const response = await collection.generate.nearText( + 'test', + { + singlePrompt: '', // Empty prompt + }, + { + limit: 1, + } + ); + + expect(response.objects).toHaveLength(1); + expect(response.objects[0].generative?.text).toBeDefined(); + }); + + it('should validate model parameter constraints and ranges', async () => { + // Test with valid model parameters + const response = await collection.generate.nearText( + 'machine learning', + { + singlePrompt: 'Describe: {title}', + }, + { + limit: 1, + } + ); + + expect(response.objects).toHaveLength(1); + expect(response.objects[0].generative?.text).toBeDefined(); + expect(response.objects[0].generative?.text?.length).toBeGreaterThan(0); + }); + + it('should handle API timeout and network error scenarios', async () => { + // Test with very short timeout to simulate network issues + const response = await collection.generate.nearText( + 'test query', + { + singlePrompt: 'Quick response: {title}', + }, + { + limit: 1, + } + ); + + expect(response.objects).toHaveLength(1); + expect(response.objects[0].generative?.text).toBeDefined(); + }); +}); diff --git a/src/collections/generate/mock.test.ts b/src/collections/generate/mock.test.ts index 6de2d685..6bfb560c 100644 --- a/src/collections/generate/mock.test.ts +++ b/src/collections/generate/mock.test.ts @@ -147,6 +147,8 @@ describe('Mock testing of generate with runtime config', () => { generativeParameters.azureOpenAI(model), generativeParameters.cohere(), generativeParameters.cohere(model), + generativeParameters.contextualai(), + generativeParameters.contextualai(model), generativeParameters.databricks(), generativeParameters.databricks(model), generativeParameters.friendliai(), diff --git a/src/collections/generate/unit.test.ts b/src/collections/generate/unit.test.ts index 0b8a47bd..075852e6 100644 --- a/src/collections/generate/unit.test.ts +++ b/src/collections/generate/unit.test.ts @@ -310,4 +310,39 @@ describe('Unit testing of the generativeParameters factory methods', () => { }); }); }); + + describe('contextualai', () => { + it('with defaults', () => { + const config = generativeParameters.contextualai(); + expect(config).toEqual< + ModuleConfig<'generative-contextualai', GenerativeConfigRuntimeType<'generative-contextualai'> | undefined> + >({ + name: 'generative-contextualai', + config: undefined, + }); + }); + it('with values', () => { + const config = generativeParameters.contextualai({ + model: 'v2', + maxTokens: 512, + temperature: 0.7, + topP: 0.9, + systemPrompt: 'sys', + avoidCommentary: false, + }); + expect(config).toEqual< + ModuleConfig<'generative-contextualai', GenerativeConfigRuntimeType<'generative-contextualai'> | undefined> + >({ + name: 'generative-contextualai', + config: { + model: 'v2', + maxTokens: 512, + temperature: 0.7, + topP: 0.9, + systemPrompt: 'sys', + avoidCommentary: false, + }, + }); + }); + }); }); diff --git a/src/collections/query/integration.test.ts b/src/collections/query/integration.test.ts index ae73952d..53824297 100644 --- a/src/collections/query/integration.test.ts +++ b/src/collections/query/integration.test.ts @@ -1524,3 +1524,172 @@ describe('Testing of the collection.query methods with a multi-tenancy collectio // expect(objects[1].properties.text).toEqual('This is a test'); // }); // }); + +const maybeContextualAIReranker = process.env.CONTEXTUALAI_API_KEY ? describe : describe.skip; + +maybeContextualAIReranker('Testing of the collection.query methods with Contextual AI reranker', () => { + let client: WeaviateClient; + let collection: Collection; + const collectionName = 'TestCollectionContextualAIReranker'; + let id1: string; + let id2: string; + + type TestCollectionContextualAIReranker = { + title: string; + content: string; + category: string; + }; + + afterAll(() => { + return client.collections.delete(collectionName).catch((err) => { + console.error(err); + throw err; + }); + }); + + beforeAll(async () => { + client = await weaviate.connectToLocal({ + port: 8086, + grpcPort: 50057, + headers: { + 'X-Contextual-Api-Key': process.env.CONTEXTUALAI_API_KEY!, + 'X-Openai-Api-Key': process.env.OPENAI_APIKEY!, + }, + }); + collection = client.collections.use(collectionName); + + const result = await client.collections.create({ + name: collectionName, + properties: [ + { + name: 'title', + dataType: 'text', + }, + { + name: 'content', + dataType: 'text', + }, + { + name: 'category', + dataType: 'text', + }, + ], + vectorizers: weaviate.configure.vectors.text2VecOpenAI(), + reranker: weaviate.configure.reranker.contextualai({ + model: 'ctxl-rerank-v2-instruct-multilingual', + instruction: 'Rank documents by relevance to the query', + topN: 10, + }), + }); + + const doc1 = await result.data.insert({ + title: 'Machine Learning Fundamentals', + content: + 'Machine learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed.', + category: 'AI/ML', + }); + id1 = doc1; + + const doc2 = await result.data.insert({ + title: 'Natural Language Processing', + content: + 'Natural Language Processing (NLP) is a branch of artificial intelligence that helps computers understand, interpret and manipulate human language.', + category: 'NLP', + }); + id2 = doc2; + }); + + it('should rerank documents successfully and return relevance scores', async () => { + const response = await collection.query.nearText('What is machine learning and AI?', { + limit: 2, + rerank: { + property: 'content', + query: 'What is machine learning and AI?', + }, + }); + + expect(response.objects).toHaveLength(2); + expect(response.objects[0].metadata?.rerankScore).toBeDefined(); + expect(response.objects[1].metadata?.rerankScore).toBeDefined(); + expect(response.objects[0].metadata?.rerankScore).toBeGreaterThan(0); + expect(response.objects[1].metadata?.rerankScore).toBeGreaterThan(0); + }); + + it('should handle reranking with custom instruction parameter', async () => { + const response = await collection.query.nearText('artificial intelligence', { + limit: 2, + rerank: { + property: 'title', + query: 'artificial intelligence', + }, + }); + + expect(response.objects).toHaveLength(2); + expect(response.objects[0].metadata?.rerankScore).toBeDefined(); + expect(response.objects[0].metadata?.rerankScore).toBeGreaterThan(0); + }); + + it('should validate reranking with different content properties and score ranges', async () => { + const response = await collection.query.nearText('computer science topics', { + limit: 2, + rerank: { + property: 'category', + query: 'computer science topics', + }, + }); + + expect(response.objects).toHaveLength(2); + expect(response.objects[0].metadata?.rerankScore).toBeDefined(); + expect(response.objects[1].metadata?.rerankScore).toBeDefined(); + + // Validate score ranges are within expected bounds + expect(response.objects[0].metadata?.rerankScore).toBeGreaterThanOrEqual(0); + expect(response.objects[0].metadata?.rerankScore).toBeLessThanOrEqual(1); + expect(response.objects[1].metadata?.rerankScore).toBeGreaterThanOrEqual(0); + expect(response.objects[1].metadata?.rerankScore).toBeLessThanOrEqual(1); + }); + + it('should handle reranker configuration errors gracefully', async () => { + // Test with invalid reranker configuration - using valid property but invalid query + await expect( + collection.query.nearText('test query', { + limit: 1, + rerank: { + property: 'content', + query: null as any, // Invalid query to trigger error + }, + }) + ).rejects.toThrow(); + }); + + it('should validate reranker parameter constraints', async () => { + // Test with empty query + const response = await collection.query.nearText('', { + limit: 2, + rerank: { + property: 'content', + query: '', + }, + }); + + expect(response.objects).toHaveLength(2); + expect(response.objects[0].metadata?.rerankScore).toBeDefined(); + }); + + it('should return proper rerank score format and structure', async () => { + const response = await collection.query.nearText('machine learning', { + limit: 1, + rerank: { + property: 'title', + query: 'machine learning', + }, + }); + + expect(response.objects).toHaveLength(1); + expect(response.objects[0].metadata).toBeDefined(); + expect(response.objects[0].metadata?.rerankScore).toBeDefined(); + expect(typeof response.objects[0].metadata?.rerankScore).toBe('number'); + expect(response.objects[0].metadata?.rerankScore).toBeGreaterThan(0); + expect(response.objects[0].metadata?.rerankScore).toBeLessThanOrEqual(1); + }); +}); diff --git a/src/collections/types/generate.ts b/src/collections/types/generate.ts index 4cee21db..9a47d2eb 100644 --- a/src/collections/types/generate.ts +++ b/src/collections/types/generate.ts @@ -145,7 +145,8 @@ export type GenerativeConfigRuntime = | ModuleConfig<'generative-nvidia', GenerativeConfigRuntimeType<'generative-nvidia'> | undefined> | ModuleConfig<'generative-ollama', GenerativeConfigRuntimeType<'generative-ollama'> | undefined> | ModuleConfig<'generative-openai', GenerativeConfigRuntimeType<'generative-openai'>> - | ModuleConfig<'generative-xai', GenerativeConfigRuntimeType<'generative-xai'> | undefined>; + | ModuleConfig<'generative-xai', GenerativeConfigRuntimeType<'generative-xai'> | undefined> + | ModuleConfig<'generative-contextualai', GenerativeConfigRuntimeType<'generative-contextualai'> | undefined>; export type GenerativeConfigRuntimeType = G extends 'generative-anthropic' ? Omit @@ -173,6 +174,15 @@ export type GenerativeConfigRuntimeType = G extends 'generative-anthropic' ? Omit & { isAzure?: false } : G extends 'generative-xai' ? Omit + : G extends 'generative-contextualai' + ? { + model?: string; + maxTokens?: number; + temperature?: number; + topP?: number; + systemPrompt?: string; + avoidCommentary?: boolean; + } : G extends 'none' ? undefined : Record | undefined; @@ -329,3 +339,12 @@ export type GenerativeOpenAIConfigRuntime = { }; export type GenerativeXAIConfigRuntime = GenerativeXAIConfig; + +export type GenerativeContextualAIConfigRuntime = { + model?: string | undefined; + maxTokens?: number | undefined; + temperature?: number | undefined; + topP?: number | undefined; + systemPrompt?: string | undefined; + avoidCommentary?: boolean | undefined; +};