Refactor ai function calls

This commit is contained in:
Sam Becker 2024-06-11 17:45:17 -05:00
parent d9fa68cbaa
commit 95746b750d

View File

@ -13,7 +13,6 @@ const openai = AI_TEXT_GENERATION_ENABLED
? createOpenAI({ apiKey: process.env.OPENAI_SECRET_KEY })
: undefined;
// Allows 100 requests per hour
const ratelimit = HAS_VERCEL_KV
? new Ratelimit({
redis: kv,
@ -21,10 +20,8 @@ const ratelimit = HAS_VERCEL_KV
})
: undefined;
export const streamOpenAiImageQuery = async (
imageBase64: string,
query: string,
) => {
// Allows 100 requests per hour
const checkRateLimitAndBailIfNecessary = async () => {
if (ratelimit) {
let success = false;
try {
@ -38,12 +35,15 @@ export const streamOpenAiImageQuery = async (
throw new Error('OpenAI rate limit exceeded');
}
}
};
const stream = createStreamableValue('');
if (openai) {
(async () => {
const { textStream } = await streamText({
const getImageTextArgs = (
imageBase64: string,
query: string,
): (
Parameters<typeof streamText>[0] &
Parameters<typeof generateText>[0]
) | undefined => openai ? {
model: openai('gpt-4o'),
messages: [{
'role': 'user',
@ -57,7 +57,21 @@ export const streamOpenAiImageQuery = async (
},
],
}],
});
} : undefined;
export const streamOpenAiImageQuery = async (
imageBase64: string,
query: string,
) => {
await checkRateLimitAndBailIfNecessary();
const stream = createStreamableValue('');
const args = getImageTextArgs(imageBase64, query);
if (args) {
(async () => {
const { textStream } = await streamText(args);
for await (const delta of textStream) {
stream.update(delta);
}
@ -72,53 +86,19 @@ export const generateOpenAiImageQuery = async (
imageBase64: string,
query: string,
) => {
if (ratelimit) {
let success = false;
try {
success = (await ratelimit.limit(RATE_LIMIT_IDENTIFIER)).success;
} catch (e: any) {
console.error('Failed to rate limit OpenAI', e);
throw new Error('Failed to rate limit OpenAI');
}
if (!success) {
console.error('OpenAI rate limit exceeded');
throw new Error('OpenAI rate limit exceeded');
}
}
await checkRateLimitAndBailIfNecessary();
if (openai) {
return generateText({
model: openai('gpt-4o'),
messages: [{
'role': 'user',
'content': [
{
'type': 'text',
'text': query,
}, {
'type': 'image',
'image': removeBase64Prefix(imageBase64),
},
],
}],
}).then(({ text }) => text);
const args = getImageTextArgs(imageBase64, query);
if (args) {
return generateText(args)
.then(({ text }) => text);
}
};
export const testOpenAiConnection = async () => {
if (ratelimit) {
let success = false;
try {
success = (await ratelimit.limit(RATE_LIMIT_IDENTIFIER)).success;
} catch (e: any) {
console.error('Failed to rate limit OpenAI', e);
throw new Error('Failed to rate limit OpenAI');
}
if (!success) {
console.error('OpenAI rate limit exceeded');
throw new Error('OpenAI rate limit exceeded');
}
}
await checkRateLimitAndBailIfNecessary();
if (openai) {
return generateText({
model: openai('gpt-4o'),
@ -131,6 +111,6 @@ export const testOpenAiConnection = async () => {
},
],
}],
}).then(({ text }) => text);
});
}
};