diff --git a/site/src/pages/TasksPage/TasksPage.stories.tsx b/site/src/pages/TasksPage/TasksPage.stories.tsx index 1b1770f586768..0bdc9d27a7eed 100644 --- a/site/src/pages/TasksPage/TasksPage.stories.tsx +++ b/site/src/pages/TasksPage/TasksPage.stories.tsx @@ -4,12 +4,14 @@ import { API } from "api/api"; import { MockUsers } from "pages/UsersPage/storybookData/users"; import { reactRouterParameters } from "storybook-addon-remix-react-router"; import { + MockAIPromptPresets, + MockNewTaskData, + MockPresets, + MockTasks, MockTemplate, MockTemplateVersionExternalAuthGithub, MockTemplateVersionExternalAuthGithubAuthenticated, MockUserOwner, - MockWorkspace, - MockWorkspaceAppStatus, mockApiError, } from "testHelpers/entities"; import { @@ -31,6 +33,7 @@ const meta: Meta = { }, beforeEach: () => { spyOn(API, "getTemplateVersionExternalAuth").mockResolvedValue([]); + spyOn(API, "getTemplateVersionPresets").mockResolvedValue(null); spyOn(API, "getUsers").mockResolvedValue({ users: MockUsers, count: MockUsers.length, @@ -53,7 +56,7 @@ type Story = StoryObj; export const LoadingAITemplates: Story = { beforeEach: () => { spyOn(data, "fetchAITemplates").mockImplementation( - () => new Promise((res) => 1000 * 60 * 60), + () => new Promise(() => 1000 * 60 * 60), ); }, }; @@ -79,7 +82,7 @@ export const LoadingTasks: Story = { beforeEach: () => { spyOn(data, "fetchAITemplates").mockResolvedValue([MockTemplate]); spyOn(data, "fetchTasks").mockImplementation( - () => new Promise((res) => 1000 * 60 * 60), + () => new Promise(() => 1000 * 60 * 60), ); }, play: async ({ canvasElement, step }) => { @@ -119,15 +122,77 @@ export const LoadedTasks: Story = { }, }; -const newTaskData = { - prompt: "Create a new task", - workspace: { - ...MockWorkspace, - id: "workspace-4", - latest_app_status: { - ...MockWorkspaceAppStatus, - message: "Task created successfully!", - }, +export const LoadedTasksWithPresets: Story = { + decorators: [withProxyProvider()], + beforeEach: () => { + const mockTemplateWithPresets = { + ...MockTemplate, + id: "test-template-2", + name: "template-with-presets", + display_name: "Template with Presets", + }; + + spyOn(data, "fetchAITemplates").mockResolvedValue([ + MockTemplate, + mockTemplateWithPresets, + ]); + spyOn(data, "fetchTasks").mockResolvedValue(MockTasks); + spyOn(API, "getTemplateVersionPresets").mockImplementation( + async (versionId) => { + // Return presets only for the second template + if (versionId === mockTemplateWithPresets.active_version_id) { + return MockPresets; + } + return null; + }, + ); + }, +}; + +export const LoadedTasksWithAIPromptPresets: Story = { + decorators: [withProxyProvider()], + beforeEach: () => { + const mockTemplateWithPresets = { + ...MockTemplate, + id: "test-template-2", + name: "template-with-presets", + display_name: "Template with AI Prompt Presets", + }; + + spyOn(data, "fetchAITemplates").mockResolvedValue([ + MockTemplate, + mockTemplateWithPresets, + ]); + spyOn(data, "fetchTasks").mockResolvedValue(MockTasks); + spyOn(API, "getTemplateVersionPresets").mockImplementation( + async (versionId) => { + // Return presets only for the second template + if (versionId === mockTemplateWithPresets.active_version_id) { + return MockAIPromptPresets; + } + return null; + }, + ); + }, +}; + +export const LoadedTasksEdgeCases: Story = { + decorators: [withProxyProvider()], + beforeEach: () => { + spyOn(data, "fetchAITemplates").mockResolvedValue([MockTemplate]); + spyOn(data, "fetchTasks").mockResolvedValue(MockTasks); + + // Test various edge cases for presets + spyOn(API, "getTemplateVersionPresets").mockImplementation(async () => { + return [ + { + ID: "malformed", + Name: "Malformed Preset", + Default: true, + }, + // biome-ignore lint/suspicious/noExplicitAny: Testing malformed data edge cases + ] as any; + }); }, }; @@ -154,15 +219,15 @@ export const CreateTaskSuccessfully: Story = { spyOn(data, "fetchAITemplates").mockResolvedValue([MockTemplate]); spyOn(data, "fetchTasks") .mockResolvedValueOnce(MockTasks) - .mockResolvedValue([newTaskData, ...MockTasks]); - spyOn(data, "createTask").mockResolvedValue(newTaskData); + .mockResolvedValue([MockNewTaskData, ...MockTasks]); + spyOn(data, "createTask").mockResolvedValue(MockNewTaskData); }, play: async ({ canvasElement, step }) => { const canvas = within(canvasElement); await step("Run task", async () => { const prompt = await canvas.findByLabelText(/prompt/i); - await userEvent.type(prompt, newTaskData.prompt); + await userEvent.type(prompt, MockNewTaskData.prompt); const submitButton = canvas.getByRole("button", { name: /run task/i }); await waitFor(() => expect(submitButton).toBeEnabled()); await userEvent.click(submitButton); @@ -208,8 +273,8 @@ export const WithAuthenticatedExternalAuth: Story = { beforeEach: () => { spyOn(data, "fetchTasks") .mockResolvedValueOnce(MockTasks) - .mockResolvedValue([newTaskData, ...MockTasks]); - spyOn(data, "createTask").mockResolvedValue(newTaskData); + .mockResolvedValue([MockNewTaskData, ...MockTasks]); + spyOn(data, "createTask").mockResolvedValue(MockNewTaskData); spyOn(API, "getTemplateVersionExternalAuth").mockResolvedValue([ MockTemplateVersionExternalAuthGithubAuthenticated, ]); @@ -235,8 +300,8 @@ export const MissingExternalAuth: Story = { beforeEach: () => { spyOn(data, "fetchTasks") .mockResolvedValueOnce(MockTasks) - .mockResolvedValue([newTaskData, ...MockTasks]); - spyOn(data, "createTask").mockResolvedValue(newTaskData); + .mockResolvedValue([MockNewTaskData, ...MockTasks]); + spyOn(data, "createTask").mockResolvedValue(MockNewTaskData); spyOn(API, "getTemplateVersionExternalAuth").mockResolvedValue([ MockTemplateVersionExternalAuthGithub, ]); @@ -246,7 +311,7 @@ export const MissingExternalAuth: Story = { await step("Submit is disabled", async () => { const prompt = await canvas.findByLabelText(/prompt/i); - await userEvent.type(prompt, newTaskData.prompt); + await userEvent.type(prompt, MockNewTaskData.prompt); const submitButton = canvas.getByRole("button", { name: /run task/i }); expect(submitButton).toBeDisabled(); }); @@ -262,8 +327,8 @@ export const ExternalAuthError: Story = { beforeEach: () => { spyOn(data, "fetchTasks") .mockResolvedValueOnce(MockTasks) - .mockResolvedValue([newTaskData, ...MockTasks]); - spyOn(data, "createTask").mockResolvedValue(newTaskData); + .mockResolvedValue([MockNewTaskData, ...MockTasks]); + spyOn(data, "createTask").mockResolvedValue(MockNewTaskData); spyOn(API, "getTemplateVersionExternalAuth").mockRejectedValue( mockApiError({ message: "Failed to load external auth", @@ -275,7 +340,7 @@ export const ExternalAuthError: Story = { await step("Submit is disabled", async () => { const prompt = await canvas.findByLabelText(/prompt/i); - await userEvent.type(prompt, newTaskData.prompt); + await userEvent.type(prompt, MockNewTaskData.prompt); const submitButton = canvas.getByRole("button", { name: /run task/i }); expect(submitButton).toBeDisabled(); }); @@ -308,35 +373,3 @@ export const NonAdmin: Story = { }); }, }; - -const MockTasks = [ - { - workspace: { - ...MockWorkspace, - latest_app_status: MockWorkspaceAppStatus, - }, - prompt: "Create competitors page", - }, - { - workspace: { - ...MockWorkspace, - id: "workspace-2", - latest_app_status: { - ...MockWorkspaceAppStatus, - message: "Avatar size fixed!", - }, - }, - prompt: "Fix user avatar size", - }, - { - workspace: { - ...MockWorkspace, - id: "workspace-3", - latest_app_status: { - ...MockWorkspaceAppStatus, - message: "Accessibility issues fixed!", - }, - }, - prompt: "Fix accessibility issues", - }, -]; diff --git a/site/src/pages/TasksPage/TasksPage.tsx b/site/src/pages/TasksPage/TasksPage.tsx index d678098affd17..4866dbfb49222 100644 --- a/site/src/pages/TasksPage/TasksPage.tsx +++ b/site/src/pages/TasksPage/TasksPage.tsx @@ -2,7 +2,11 @@ import Skeleton from "@mui/material/Skeleton"; import { API } from "api/api"; import { getErrorDetail, getErrorMessage } from "api/errors"; import { disabledRefetchOptions } from "api/queries/util"; -import type { Template, TemplateVersionExternalAuth } from "api/typesGenerated"; +import type { + Preset, + Template, + TemplateVersionExternalAuth, +} from "api/typesGenerated"; import { ErrorAlert } from "components/Alert/ErrorAlert"; import { Avatar } from "components/Avatar/Avatar"; import { AvatarData } from "components/Avatar/AvatarData"; @@ -36,6 +40,7 @@ import { TableRowSkeleton, } from "components/TableLoader/TableLoader"; +import { templateVersionPresets } from "api/queries/templates"; import { ExternalImage } from "components/ExternalImage/ExternalImage"; import { FeatureStageBadge } from "components/FeatureStageBadge/FeatureStageBadge"; import { @@ -50,7 +55,7 @@ import { RedoIcon, RotateCcwIcon, SendIcon } from "lucide-react"; import { AI_PROMPT_PARAMETER_NAME, type Task } from "modules/tasks/tasks"; import { WorkspaceAppStatus } from "modules/workspaces/WorkspaceAppStatus/WorkspaceAppStatus"; import { generateWorkspaceName } from "modules/workspaces/generateWorkspaceName"; -import { type FC, type ReactNode, useState } from "react"; +import { type FC, type ReactNode, useEffect, useState } from "react"; import { Helmet } from "react-helmet-async"; import { useMutation, useQuery, useQueryClient } from "react-query"; import { Link as RouterLink, useNavigate } from "react-router-dom"; @@ -210,7 +215,11 @@ const TaskFormSection: FC<{ ); }; -type CreateTaskMutationFnProps = { prompt: string; templateVersionId: string }; +type CreateTaskMutationFnProps = { + prompt: string; + templateVersionId: string; + presetId: string | null; +}; type TaskFormProps = { templates: Template[]; @@ -223,15 +232,49 @@ const TaskForm: FC = ({ templates, onSuccess }) => { const [selectedTemplateId, setSelectedTemplateId] = useState( templates[0].id, ); + const [selectedPresetId, setSelectedPresetId] = useState(null); const selectedTemplate = templates.find( (t) => t.id === selectedTemplateId, ) as Template; + const { externalAuth, externalAuthError, isPollingExternalAuth, isLoadingExternalAuth, } = useExternalAuth(selectedTemplate.active_version_id); + + // Fetch presets when template changes + const { data: presetsData, isLoading: isLoadingPresets } = useQuery< + Preset[] | null, + Error + >(templateVersionPresets(selectedTemplate.active_version_id)); + + // Handle preset selection when data changes + useEffect(() => { + if (presetsData === undefined) { + // Still loading + return; + } + + if (!presetsData || presetsData.length === 0) { + setSelectedPresetId(null); + return; + } + + // Always select the default preset when new data arrives + const defaultPreset = presetsData.find((p: Preset) => p.Default); + const defaultPresetID = defaultPreset?.ID || null; + setSelectedPresetId(defaultPresetID); + }, [presetsData]); + + // Extract AI prompt from selected preset + const selectedPreset = presetsData?.find((p) => p.ID === selectedPresetId); + const presetAIPrompt = selectedPreset?.Parameters?.find( + (param) => param.Name === AI_PROMPT_PARAMETER_NAME, + )?.Value; + const isPromptReadOnly = !!presetAIPrompt; + const missedExternalAuth = externalAuth?.filter( (auth) => !auth.optional && !auth.authenticated, ); @@ -243,8 +286,9 @@ const TaskForm: FC = ({ templates, onSuccess }) => { mutationFn: async ({ prompt, templateVersionId, + presetId, }: CreateTaskMutationFnProps) => - data.createTask(prompt, user.id, templateVersionId), + data.createTask(prompt, user.id, templateVersionId, presetId), onSuccess: async (task) => { await queryClient.invalidateQueries({ queryKey: ["tasks"], @@ -258,13 +302,13 @@ const TaskForm: FC = ({ templates, onSuccess }) => { const form = e.currentTarget; const formData = new FormData(form); - const prompt = formData.get("prompt") as string; - const templateID = formData.get("templateID") as string; + const prompt = presetAIPrompt || (formData.get("prompt") as string); try { await createTaskMutation.mutateAsync({ prompt, templateVersionId: selectedTemplate.active_version_id, + presetId: selectedPresetId, }); } catch (error) { const message = getErrorMessage(error, "Error creating task"); @@ -285,39 +329,113 @@ const TaskForm: FC = ({ templates, onSuccess }) => { className="border border-border border-solid rounded-lg p-4" disabled={createTaskMutation.isPending} > -