From 95746b750d5a06316173471045ee8a8e9bf1df84 Mon Sep 17 00:00:00 2001 From: Sam Becker Date: Tue, 11 Jun 2024 17:45:17 -0500 Subject: [PATCH] Refactor ai function calls --- src/services/openai.ts | 108 +++++++++++++++++------------------------ 1 file changed, 44 insertions(+), 64 deletions(-) diff --git a/src/services/openai.ts b/src/services/openai.ts index 83c625d6..443ea7d4 100644 --- a/src/services/openai.ts +++ b/src/services/openai.ts @@ -13,7 +13,6 @@ const openai = AI_TEXT_GENERATION_ENABLED ? createOpenAI({ apiKey: process.env.OPENAI_SECRET_KEY }) : undefined; -// Allows 100 requests per hour const ratelimit = HAS_VERCEL_KV ? new Ratelimit({ redis: kv, @@ -21,10 +20,8 @@ const ratelimit = HAS_VERCEL_KV }) : undefined; -export const streamOpenAiImageQuery = async ( - imageBase64: string, - query: string, -) => { +// Allows 100 requests per hour +const checkRateLimitAndBailIfNecessary = async () => { if (ratelimit) { let success = false; try { @@ -38,26 +35,43 @@ export const streamOpenAiImageQuery = async ( throw new Error('OpenAI rate limit exceeded'); } } +}; + +const getImageTextArgs = ( + imageBase64: string, + query: string, +): ( + Parameters[0] & + Parameters[0] +) | undefined => openai ? { + model: openai('gpt-4o'), + messages: [{ + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': query, + }, { + 'type': 'image', + 'image': removeBase64Prefix(imageBase64), + }, + ], + }], +} : undefined; + +export const streamOpenAiImageQuery = async ( + imageBase64: string, + query: string, +) => { + await checkRateLimitAndBailIfNecessary(); const stream = createStreamableValue(''); - if (openai) { + const args = getImageTextArgs(imageBase64, query); + + if (args) { (async () => { - const { textStream } = await streamText({ - model: openai('gpt-4o'), - messages: [{ - 'role': 'user', - 'content': [ - { - 'type': 'text', - 'text': query, - }, { - 'type': 'image', - 'image': removeBase64Prefix(imageBase64), - }, - ], - }], - }); + const { textStream } = await streamText(args); for await (const delta of textStream) { stream.update(delta); } @@ -72,53 +86,19 @@ export const generateOpenAiImageQuery = async ( imageBase64: string, query: string, ) => { - if (ratelimit) { - let success = false; - try { - success = (await ratelimit.limit(RATE_LIMIT_IDENTIFIER)).success; - } catch (e: any) { - console.error('Failed to rate limit OpenAI', e); - throw new Error('Failed to rate limit OpenAI'); - } - if (!success) { - console.error('OpenAI rate limit exceeded'); - throw new Error('OpenAI rate limit exceeded'); - } - } + await checkRateLimitAndBailIfNecessary(); - if (openai) { - return generateText({ - model: openai('gpt-4o'), - messages: [{ - 'role': 'user', - 'content': [ - { - 'type': 'text', - 'text': query, - }, { - 'type': 'image', - 'image': removeBase64Prefix(imageBase64), - }, - ], - }], - }).then(({ text }) => text); + const args = getImageTextArgs(imageBase64, query); + + if (args) { + return generateText(args) + .then(({ text }) => text); } }; export const testOpenAiConnection = async () => { - if (ratelimit) { - let success = false; - try { - success = (await ratelimit.limit(RATE_LIMIT_IDENTIFIER)).success; - } catch (e: any) { - console.error('Failed to rate limit OpenAI', e); - throw new Error('Failed to rate limit OpenAI'); - } - if (!success) { - console.error('OpenAI rate limit exceeded'); - throw new Error('OpenAI rate limit exceeded'); - } - } + await checkRateLimitAndBailIfNecessary(); + if (openai) { return generateText({ model: openai('gpt-4o'), @@ -131,6 +111,6 @@ export const testOpenAiConnection = async () => { }, ], }], - }).then(({ text }) => text); + }); } };