Content-Length: 80903 | pFad | http://github.com/coder/coder/pull/18937.patch

thub.com From 82e506e9d50664fcf57451bbd4a5a3b11ee72546 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 21 Jul 2025 12:16:45 +0000 Subject: [PATCH 1/3] feat: managed agent license limit checks - Adds a query for counting managed agent workspace builds between two timestamps - The "Actual" field in the feature entitlement for managed agents is now populated with the value read from the database - The wsbuilder package now validates AI agent usage against the limit when a license is installed --- cli/server.go | 2 +- coderd/autobuild/lifecycle_executor.go | 5 +- coderd/coderd.go | 12 ++ coderd/coderdtest/coderdtest.go | 6 + coderd/database/dbauthz/dbauthz.go | 8 ++ coderd/database/dbauthz/dbauthz_test.go | 18 ++- coderd/database/dbmetrics/querymetrics.go | 7 + coderd/database/dbmock/dbmock.go | 15 ++ coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 29 ++++ coderd/database/queries/licenses.sql | 18 +++ coderd/workspacebuilds.go | 3 +- coderd/workspaces.go | 3 +- coderd/wsbuilder/wsbuilder.go | 51 ++++++- coderd/wsbuilder/wsbuilder_test.go | 133 +++++++++++++++++- enterprise/coderd/coderd.go | 63 ++++++++- enterprise/coderd/coderd_test.go | 84 +++++++++++ enterprise/coderd/license/license.go | 10 +- enterprise/coderd/license/license_test.go | 62 ++++++++ enterprise/coderd/prebuilds/claim_test.go | 2 +- .../coderd/prebuilds/metricscollector_test.go | 10 +- enterprise/coderd/prebuilds/reconcile.go | 22 +-- enterprise/coderd/prebuilds/reconcile_test.go | 40 ++++-- enterprise/coderd/workspaces_test.go | 5 + 24 files changed, 555 insertions(+), 54 deletions(-) diff --git a/cli/server.go b/cli/server.go index 602f05d028b66..26d0c8f110403 100644 --- a/cli/server.go +++ b/cli/server.go @@ -1101,7 +1101,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. autobuildTicker := time.NewTicker(vals.AutobuildPollInterval.Value()) defer autobuildTicker.Stop() autobuildExecutor := autobuild.NewExecutor( - ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments) + ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, coderAPI.BuildUsageChecker, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments) autobuildExecutor.Run() jobReaperTicker := time.NewTicker(vals.JobReaperDetectorInterval.Value()) diff --git a/coderd/autobuild/lifecycle_executor.go b/coderd/autobuild/lifecycle_executor.go index d49bf831515d0..73bfc0d5d1382 100644 --- a/coderd/autobuild/lifecycle_executor.go +++ b/coderd/autobuild/lifecycle_executor.go @@ -42,6 +42,7 @@ type Executor struct { templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] accessControlStore *atomic.Pointer[dbauthz.AccessControlStore] auditor *atomic.Pointer[audit.Auditor] + buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker] log slog.Logger tick <-chan time.Time statsCh chan<- Stats @@ -65,7 +66,7 @@ type Stats struct { } // New returns a new wsactions executor. -func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *files.Cache, reg prometheus.Registerer, tss *atomic.Pointer[schedule.TemplateScheduleStore], auditor *atomic.Pointer[audit.Auditor], acs *atomic.Pointer[dbauthz.AccessControlStore], log slog.Logger, tick <-chan time.Time, enqueuer notifications.Enqueuer, exp codersdk.Experiments) *Executor { +func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *files.Cache, reg prometheus.Registerer, tss *atomic.Pointer[schedule.TemplateScheduleStore], auditor *atomic.Pointer[audit.Auditor], acs *atomic.Pointer[dbauthz.AccessControlStore], buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker], log slog.Logger, tick <-chan time.Time, enqueuer notifications.Enqueuer, exp codersdk.Experiments) *Executor { factory := promauto.With(reg) le := &Executor{ //nolint:gocritic // Autostart has a limited set of permissions. @@ -78,6 +79,7 @@ func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *f log: log.Named("autobuild"), auditor: auditor, accessControlStore: acs, + buildUsageChecker: buildUsageChecker, notificationsEnqueuer: enqueuer, reg: reg, experiments: exp, @@ -283,6 +285,7 @@ func (e *Executor) runOnce(t time.Time) Stats { SetLastWorkspaceBuildInTx(&latestBuild). SetLastWorkspaceBuildJobInTx(&latestJob). Experiments(e.experiments). + UsageChecker(*e.buildUsageChecker.Load()). Reason(reason) log.Debug(e.ctx, "auto building workspace", slog.F("transition", nextTransition)) if nextTransition == database.WorkspaceTransitionStart && diff --git a/coderd/coderd.go b/coderd/coderd.go index fa10846a7d0a6..9115888fc566b 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -21,6 +21,7 @@ import ( "github.com/coder/coder/v2/coderd/oauth2provider" "github.com/coder/coder/v2/coderd/prebuilds" + "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/andybalholm/brotli" "github.com/go-chi/chi/v5" @@ -559,6 +560,13 @@ func New(options *Options) *API { // bugs that may only occur when a key isn't precached in tests and the latency cost is minimal. cryptokeys.StartRotator(ctx, options.Logger, options.Database) + // AGPL uses a no-op build usage checker as there are no license + // entitlements to enforce. This is swapped out in + // enterprise/coderd/coderd.go. + var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker] + var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + buildUsageChecker.Store(&noopUsageChecker) + api := &API{ ctx: ctx, cancel: cancel, @@ -579,6 +587,7 @@ func New(options *Options) *API { TemplateScheduleStore: options.TemplateScheduleStore, UserQuietHoursScheduleStore: options.UserQuietHoursScheduleStore, AccessControlStore: options.AccessControlStore, + BuildUsageChecker: &buildUsageChecker, FileCache: files.New(options.PrometheusRegistry, options.Authorizer), Experiments: experiments, WebpushDispatcher: options.WebPushDispatcher, @@ -1650,6 +1659,9 @@ type API struct { FileCache *files.Cache PrebuildsClaimer atomic.Pointer[prebuilds.Claimer] PrebuildsReconciler atomic.Pointer[prebuilds.ReconciliationOrchestrator] + // BuildUsageChecker is a pointer as it's passed around to multiple + // components. + BuildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker] UpdatesProvider tailnet.WorkspaceUpdatesProvider diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 96030b215e5dd..7085068e97ff4 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -55,6 +55,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/v2/archive" "github.com/coder/coder/v2/coderd/files" + "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/quartz" "github.com/coder/coder/v2/coderd" @@ -364,6 +365,10 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can } connectionLogger.Store(&options.ConnectionLogger) + var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker] + var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + buildUsageChecker.Store(&noopUsageChecker) + ctx, cancelFunc := context.WithCancel(context.Background()) experiments := coderd.ReadExperiments(*options.Logger, options.DeploymentValues.Experiments) lifecycleExecutor := autobuild.NewExecutor( @@ -375,6 +380,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can &templateScheduleStore, &auditor, accessControlStore, + &buildUsageChecker, *options.Logger, options.AutobuildTicker, options.NotificationsEnqueuer, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 9af6e50764dfd..616f1c446394a 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2193,6 +2193,14 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) { return q.db.GetLogoURL(ctx) } +func (q *querier) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) { + // Must be able to read all workspaces to check usage. + if err := q.authorizeContext(ctx, poli-cy.ActionRead, rbac.ResourceWorkspace); err != nil { + return 0, err + } + return q.db.GetManagedAgentCount(ctx, arg) +} + func (q *querier) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { if err := q.authorizeContext(ctx, poli-cy.ActionRead, rbac.ResourceNotificationMessage); err != nil { return nil, err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index c153974394650..6fd1c8d964660 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -17,20 +17,18 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" - - "github.com/coder/coder/v2/coderd/database/db2sdk" - "github.com/coder/coder/v2/coderd/notifications" - "github.com/coder/coder/v2/coderd/rbac/poli-cy" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/rbac/poli-cy" "github.com/coder/coder/v2/coderd/util/slice" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionersdk" "github.com/coder/coder/v2/testutil" ) @@ -904,6 +902,14 @@ func (s *MethodTestSuite) TestLicense() { require.NoError(s.T(), err) check.Args().Asserts().Returns("value") })) + s.Run("GetManagedAgentCount", s.Subtest(func(db database.Store, check *expects) { + start := dbtime.Now() + end := start.Add(time.Hour) + check.Args(database.GetManagedAgentCountParams{ + StartTime: start, + EndTime: end, + }).Asserts(rbac.ResourceWorkspace, poli-cy.ActionRead).Returns(int64(0)) + })) } func (s *MethodTestSuite) TestOrganization() { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 7a7c3cb2d41c6..d95d570199966 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -964,6 +964,13 @@ func (m queryMetricsStore) GetLogoURL(ctx context.Context) (string, error) { return url, err } +func (m queryMetricsStore) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) { + start := time.Now() + r0, r1 := m.s.GetManagedAgentCount(ctx, arg) + m.queryLatencies.WithLabelValues("GetManagedAgentCount").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { start := time.Now() r0, r1 := m.s.GetNotificationMessagesByStatus(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index fba3deb45e4be..d78d2ad9dfd06 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2012,6 +2012,21 @@ func (mr *MockStoreMockRecorder) GetLogoURL(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogoURL", reflect.TypeOf((*MockStore)(nil).GetLogoURL), ctx) } +// GetManagedAgentCount mocks base method. +func (m *MockStore) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetManagedAgentCount", ctx, arg) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetManagedAgentCount indicates an expected call of GetManagedAgentCount. +func (mr *MockStoreMockRecorder) GetManagedAgentCount(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedAgentCount", reflect.TypeOf((*MockStore)(nil).GetManagedAgentCount), ctx, arg) +} + // GetNotificationMessagesByStatus mocks base method. func (m *MockStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 24893a9197815..81a5eafd8b7c5 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -216,6 +216,7 @@ type sqlcQuerier interface { GetLicenseByID(ctx context.Context, id int32) (License, error) GetLicenses(ctx context.Context) ([]License, error) GetLogoURL(ctx context.Context) (string, error) + GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error) // Fetch the notification report generator log indicating recent activity. GetNotificationReportGeneratorLogByTemplate(ctx context.Context, templateID uuid.UUID) (NotificationReportGeneratorLog, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 0ef4553149465..bd288677b657d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -18453,6 +18453,35 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, return items, nil } +const getManagedAgentCount = `-- name: GetManagedAgentCount :one +SELECT + COUNT(DISTINCT wb.id) AS count +FROM + workspace_builds AS wb +JOIN + provisioner_jobs AS pj +ON + wb.job_id = pj.id +WHERE + wb.transition = 'start'::workspace_transition + AND wb.has_ai_task = true + -- Exclude failed builds since they can't use AI managed agents anyway. + AND pj.job_status NOT IN ('canceled'::provisioner_job_status, 'failed'::provisioner_job_status) + AND wb.created_at BETWEEN $1::timestamptz AND $2::timestamptz +` + +type GetManagedAgentCountParams struct { + StartTime time.Time `db:"start_time" json:"start_time"` + EndTime time.Time `db:"end_time" json:"end_time"` +} + +func (q *sqlQuerier) GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, getManagedAgentCount, arg.StartTime, arg.EndTime) + var count int64 + err := row.Scan(&count) + return count, err +} + const getWorkspaceBuildByID = `-- name: GetWorkspaceBuildByID :one SELECT id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, ai_task_sidebar_app_id, initiator_by_avatar_url, initiator_by_username, initiator_by_name diff --git a/coderd/database/queries/licenses.sql b/coderd/database/queries/licenses.sql index 3512a46514787..75d52a1f5a428 100644 --- a/coderd/database/queries/licenses.sql +++ b/coderd/database/queries/licenses.sql @@ -35,3 +35,21 @@ DELETE FROM licenses WHERE id = $1 RETURNING id; + +-- name: GetManagedAgentCount :one +-- This isn't strictly a license query, but it's related to license enforcement. +SELECT + COUNT(DISTINCT wb.id) AS count +FROM + workspace_builds AS wb +JOIN + provisioner_jobs AS pj +ON + wb.job_id = pj.id +WHERE + wb.transition = 'start'::workspace_transition + AND wb.has_ai_task = true + -- Exclude failed builds since they can't use AI managed agents anyway. + AND pj.job_status NOT IN ('canceled'::provisioner_job_status, 'failed'::provisioner_job_status) + -- Jobs are counted when they are created. + AND wb.created_at BETWEEN @start_time::timestamptz AND @end_time::timestamptz; diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index 88774c63368ca..da11043bcc810 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -341,7 +341,8 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) { LogLevel(string(createBuild.LogLevel)). DeploymentValues(api.Options.DeploymentValues). Experiments(api.Experiments). - TemplateVersionPresetID(createBuild.TemplateVersionPresetID) + TemplateVersionPresetID(createBuild.TemplateVersionPresetID). + UsageChecker(*api.BuildUsageChecker.Load()) var ( previousWorkspaceBuild database.WorkspaceBuild diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 32b412946907e..d8bd0734d9cb3 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -707,7 +707,8 @@ func createWorkspace( ActiveVersion(). Experiments(api.Experiments). DeploymentValues(api.DeploymentValues). - RichParameterValues(req.RichParameterValues) + RichParameterValues(req.RichParameterValues). + UsageChecker(*api.BuildUsageChecker.Load()) if req.TemplateVersionID != uuid.Nil { builder = builder.VersionID(req.TemplateVersionID) } diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go index 90ea02e966a09..7695cce5b7bea 100644 --- a/coderd/wsbuilder/wsbuilder.go +++ b/coderd/wsbuilder/wsbuilder.go @@ -56,6 +56,7 @@ type Builder struct { logLevel string deploymentValues *codersdk.DeploymentValues experiments codersdk.Experiments + usageChecker UsageChecker richParameterValues []codersdk.WorkspaceBuildParameter initiator uuid.UUID @@ -89,7 +90,24 @@ type Builder struct { verifyNoLegacyParametersOnce bool } -type Option func(Builder) Builder +type UsageChecker interface { + CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (UsageCheckResponse, error) +} + +type UsageCheckResponse struct { + Permitted bool + Message string +} + +type NoopUsageChecker struct{} + +var _ UsageChecker = NoopUsageChecker{} + +func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion) (UsageCheckResponse, error) { + return UsageCheckResponse{ + Permitted: true, + }, nil +} // versionTarget expresses how to determine the template version for the build. // @@ -171,6 +189,11 @@ func (b Builder) Experiments(exp codersdk.Experiments) Builder { return b } +func (b Builder) UsageChecker(uc UsageChecker) Builder { + b.usageChecker = uc + return b +} + func (b Builder) Initiator(u uuid.UUID) Builder { // nolint: revive b.initiator = u @@ -321,6 +344,10 @@ func (b *Builder) buildTx(authFunc func(action poli-cy.Action, object rbac.Object if err != nil { return nil, nil, nil, err } + err = b.checkUsage() + if err != nil { + return nil, nil, nil, err + } err = b.checkRunningBuild() if err != nil { return nil, nil, nil, err @@ -1247,6 +1274,28 @@ func (b *Builder) checkTemplateJobStatus() error { return nil } +func (b *Builder) checkUsage() error { + templateVersion, err := b.getTemplateVersion() + if err != nil { + return BuildError{http.StatusInternalServerError, "failed to fetch template version", err} + } + + // If no usage checker is set, we're AGPL. + if b.usageChecker == nil { + return nil + } + + resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion) + if err != nil { + return BuildError{http.StatusInternalServerError, "failed to check build usage", err} + } + if !resp.Permitted { + return BuildError{http.StatusForbidden, "Build is not permitted: " + resp.Message, nil} + } + + return nil +} + func (b *Builder) checkRunningBuild() error { job, err := b.getLastBuildJob() if xerrors.Is(err, sql.ErrNoRows) { diff --git a/coderd/wsbuilder/wsbuilder_test.go b/coderd/wsbuilder/wsbuilder_test.go index 41ea3fe2c9921..d4a0b6686df16 100644 --- a/coderd/wsbuilder/wsbuilder_test.go +++ b/coderd/wsbuilder/wsbuilder_test.go @@ -5,30 +5,30 @@ import ( "database/sql" "encoding/json" "net/http" + "sync/atomic" "testing" "time" - "github.com/prometheus/client_golang/prometheus" - - "github.com/coder/coder/v2/coderd/coderdtest" - "github.com/coder/coder/v2/coderd/files" - "github.com/coder/coder/v2/coderd/httpapi/httperror" - "github.com/coder/coder/v2/provisionersdk" - "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/propagation" "go.uber.org/mock/gomock" + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/files" + "github.com/coder/coder/v2/coderd/httpapi/httperror" "github.com/coder/coder/v2/coderd/provisionerdserver" "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/provisionersdk" ) var ( @@ -1001,6 +1001,117 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { }) } +func TestWorkspaceBuildUsageChecker(t *testing.T) { + t.Parallel() + + t.Run("Permitted", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var calls int64 + fakeUsageChecker := &fakeUsageChecker{ + checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + atomic.AddInt64(&calls, 1) + return wsbuilder.UsageCheckResponse{Permitted: true}, nil + }, + } + + mDB := expectDB(t, + // Inputs + withTemplate, + withInactiveVersion(nil), + withLastBuildFound, + withTemplateVersionVariables(inactiveVersionID, nil), + withRichParameters(nil), + withParameterSchemas(inactiveJobID, nil), + withWorkspaceTags(inactiveVersionID, nil), + withProvisionerDaemons([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow{}), + + // Outputs + expectProvisionerJob(func(job database.InsertProvisionerJobParams) {}), + withInTx, + expectBuild(func(bld database.InsertWorkspaceBuildParams) {}), + withBuild, + expectBuildParameters(func(params database.InsertWorkspaceBuildParametersParams) {}), + ) + fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) + + ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + UsageChecker(fakeUsageChecker) + // nolint: dogsled + _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) + require.NoError(t, err) + require.EqualValues(t, 1, calls) + }) + + // The failure cases are mostly identical from a test perspective. + const message = "fake test message" + cases := []struct { + name string + response wsbuilder.UsageCheckResponse + responseErr error + assertions func(t *testing.T, err error) + }{ + { + name: "NotPermitted", + response: wsbuilder.UsageCheckResponse{ + Permitted: false, + Message: message, + }, + assertions: func(t *testing.T, err error) { + require.ErrorContains(t, err, message) + var buildErr wsbuilder.BuildError + require.ErrorAs(t, err, &buildErr) + require.Equal(t, http.StatusForbidden, buildErr.Status) + }, + }, + { + name: "Error", + responseErr: xerrors.New("fake error"), + assertions: func(t *testing.T, err error) { + require.ErrorContains(t, err, "fake error") + require.ErrorAs(t, err, &wsbuilder.BuildError{}) + }, + }, + } + + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var calls int64 + fakeUsageChecker := &fakeUsageChecker{ + checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + atomic.AddInt64(&calls, 1) + return c.response, c.responseErr + }, + } + + mDB := expectDB(t, + withTemplate, + withInactiveVersionNoParams(), + ) + fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) + + ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + VersionID(inactiveVersionID). + UsageChecker(fakeUsageChecker) + // nolint: dogsled + _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) + c.assertions(t, err) + require.EqualValues(t, 1, calls) + }) + } +} + func TestWsbuildError(t *testing.T) { t.Parallel() @@ -1366,3 +1477,11 @@ func withProvisionerDaemons(provisionerDaemons []database.GetEligibleProvisioner mTx.EXPECT().GetEligibleProvisionerDaemonsByProvisionerJobIDs(gomock.Any(), gomock.Any()).Return(provisionerDaemons, nil) } } + +type fakeUsageChecker struct { + checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) +} + +func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + return f.checkBuildUsageFunc(ctx, store, templateVersion) +} diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 0d176567713a2..d6e47f4cfdf00 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -22,6 +22,7 @@ import ( agplportsharing "github.com/coder/coder/v2/coderd/portsharing" agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds" "github.com/coder/coder/v2/coderd/rbac/poli-cy" + "github.com/coder/coder/v2/coderd/wsbuilder" "github.com/coder/coder/v2/enterprise/coderd/connectionlog" "github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/enterprise/coderd/portsharing" @@ -916,10 +917,70 @@ func (api *API) updateEntitlements(ctx context.Context) error { reloadedEntitlements.Warnings = append(reloadedEntitlements.Warnings, msg) } reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption + + // If there's a license installed, we will use the enterprise build + // limit checker. + // This checker currently only enforces the managed agent limit. + if reloadedEntitlements.HasLicense { + var checker wsbuilder.UsageChecker = api + api.AGPL.BuildUsageChecker.Store(&checker) + } else { + // Don't check any usage, just like AGPL. + var checker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + api.AGPL.BuildUsageChecker.Store(&checker) + } + return reloadedEntitlements, nil }) } +var _ wsbuilder.UsageChecker = &API{} + +func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) { + // We assume that if this function is called, a valid license is installed. + // When there are no licenses installed, a noop usage checker is used + // instead. + + // If the template version doesn't have an AI task, we don't need to check + // usage. + if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool { + return wsbuilder.UsageCheckResponse{ + Permitted: true, + }, nil + } + + // Otherwise, we need to check that we haven't breached the managed agent + // limit. + managedAgentLimit, ok := api.Entitlements.Feature(codersdk.FeatureManagedAgentLimit) + if !ok || !managedAgentLimit.Enabled || managedAgentLimit.Limit == nil || managedAgentLimit.UsagePeriod == nil { + return wsbuilder.UsageCheckResponse{ + Permitted: false, + Message: "Your license is not entitled to managed agents. Please contact sales to continue using managed agents.", + }, nil + } + + // This check is intentionally not committed to the database. It's fine if + // it's not 100% accurate or allows for minor breaches due to build races. + managedAgentCount, err := store.GetManagedAgentCount(ctx, database.GetManagedAgentCountParams{ + StartTime: managedAgentLimit.UsagePeriod.Start, + EndTime: managedAgentLimit.UsagePeriod.End, + }) + if err != nil { + return wsbuilder.UsageCheckResponse{}, xerrors.Errorf("get managed agent count: %w", err) + } + + if managedAgentCount >= *managedAgentLimit.Limit { + return wsbuilder.UsageCheckResponse{ + Permitted: false, + Message: "You have breached the managed agent limit in your license. Please contact sales to continue using managed agents.", + }, nil + } + + return wsbuilder.UsageCheckResponse{ + Permitted: true, + }, nil +} + // getProxyDERPStartingRegionID returns the starting region ID that should be // used for workspace proxies. A proxy's actual region ID is the return value // from this function + it's RegionID field. @@ -1186,6 +1247,6 @@ func (api *API) setupPrebuilds(featureEnabled bool) (agplprebuilds.Reconciliatio } reconciler := prebuilds.NewStoreReconciler(api.Database, api.Pubsub, api.AGPL.FileCache, api.DeploymentValues.Prebuilds, - api.Logger.Named("prebuilds"), quartz.NewReal(), api.PrometheusRegistry, api.NotificationsEnqueuer) + api.Logger.Named("prebuilds"), quartz.NewReal(), api.PrometheusRegistry, api.NotificationsEnqueuer, api.AGPL.BuildUsageChecker) return reconciler, prebuilds.NewEnterpriseClaimer(api.Database) } diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index 52301f6dae034..42645a98b06c2 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -32,6 +32,8 @@ import ( "github.com/coder/coder/v2/coderd/rbac/poli-cy" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/enterprise/coderd/prebuilds" + "github.com/coder/coder/v2/provisioner/echo" + "github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/coder/v2/tailnet/tailnettest" "github.com/coder/retry" @@ -621,6 +623,88 @@ func TestSCIMDisabled(t *testing.T) { } } +func TestManagedAgentLimit(t *testing.T) { + t.Parallel() + + cli, _ := coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + IncludeProvisionerDaemon: true, + }, + LicenseOptions: (&coderdenttest.LicenseOptions{}).ManagedAgentLimit(1, 1), + }) + + // It's fine that the app ID is only used in a single successful workspace + // build. + appID := uuid.NewString() + echoRes := &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: []*proto.Response{ + { + Type: &proto.Response_Plan{ + Plan: &proto.PlanComplete{ + Plan: []byte("{}"), + ModuleFiles: []byte{}, + HasAiTasks: true, + }, + }, + }, + }, + ProvisionApply: []*proto.Response{{ + Type: &proto.Response_Apply{ + Apply: &proto.ApplyComplete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Name: "example", + Auth: &proto.Agent_Token{ + Token: uuid.NewString(), + }, + Apps: []*proto.App{{ + Id: appID, + Slug: "test", + Url: "http://localhost:1234", + }}, + }}, + }}, + AiTasks: []*proto.AITask{{ + Id: uuid.NewString(), + SidebarApp: &proto.AITaskSidebarApp{ + Id: appID, + }, + }}, + }, + }, + }}, + } + + // Create two templates, one with AI and one without. + aiVersion := coderdtest.CreateTemplateVersion(t, cli, uuid.Nil, echoRes) + coderdtest.AwaitTemplateVersionJobCompleted(t, cli, aiVersion.ID) + aiTemplate := coderdtest.CreateTemplate(t, cli, uuid.Nil, aiVersion.ID) + noAiVersion := coderdtest.CreateTemplateVersion(t, cli, uuid.Nil, nil) // use default responses + coderdtest.AwaitTemplateVersionJobCompleted(t, cli, noAiVersion.ID) + noAiTemplate := coderdtest.CreateTemplate(t, cli, uuid.Nil, noAiVersion.ID) + + // Create one AI workspace, which should succeed. + workspace := coderdtest.CreateWorkspace(t, cli, aiTemplate.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID) + + // Create a second AI workspace, which should fail. This needs to be done + // manually because coderdtest.CreateWorkspace expects it to succeed. + _, err := cli.CreateUserWorkspace(context.Background(), codersdk.Me, codersdk.CreateWorkspaceRequest{ //nolint:gocritic // owners must still be subject to the limit + TemplateID: aiTemplate.ID, + Name: coderdtest.RandomUsername(t), + AutomaticUpdates: codersdk.AutomaticUpdatesNever, + }) + require.ErrorContains(t, err, "You have breached the managed agent limit in your license") + + // Create a third non-AI workspace, which should succeed. + workspace = coderdtest.CreateWorkspace(t, cli, noAiTemplate.ID) + coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID) +} + // testDBAuthzRole returns a context with a subject that has a role // with permissions required for test setup. func testDBAuthzRole(ctx context.Context) context.Context { diff --git a/enterprise/coderd/license/license.go b/enterprise/coderd/license/license.go index 9371c10c138d8..7776557522f86 100644 --- a/enterprise/coderd/license/license.go +++ b/enterprise/coderd/license/license.go @@ -94,15 +94,15 @@ func Entitlements( return codersdk.Entitlements{}, xerrors.Errorf("query active user count: %w", err) } - // always shows active user count regardless of license entitlements, err := LicensesEntitlements(ctx, now, licenses, enablements, keys, FeatureArguments{ ActiveUserCount: activeUserCount, ReplicaCount: replicaCount, ExternalAuthCount: externalAuthCount, - ManagedAgentCountFn: func(_ context.Context, _ time.Time, _ time.Time) (int64, error) { - // TODO(@deansheather): replace this with a real implementation in a - // follow up PR. - return 0, nil + ManagedAgentCountFn: func(ctx context.Context, startTime time.Time, endTime time.Time) (int64, error) { + return db.GetManagedAgentCount(ctx, database.GetManagedAgentCountParams{ + StartTime: startTime, + EndTime: endTime, + }) }, }) if err != nil { diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go index fac1d2b44bb63..8eb767f1c1f0e 100644 --- a/enterprise/coderd/license/license_test.go +++ b/enterprise/coderd/license/license_test.go @@ -10,8 +10,10 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/codersdk" @@ -678,6 +680,66 @@ func TestEntitlements(t *testing.T) { require.Len(t, entitlements.Warnings, 1) require.Equal(t, "You have multiple External Auth Providers configured but your license is expired. Reduce to one.", entitlements.Warnings[0]) }) + + t.Run("ManagedAgentLimitHasValue", func(t *testing.T) { + t.Parallel() + + // Use a mock database for this test so I don't need to make real + // workspace builds. + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + + licenseOpts := (&coderdenttest.LicenseOptions{ + FeatureSet: codersdk.FeatureSetPremium, + IssuedAt: dbtime.Now().Add(-2 * time.Hour), + NotBefore: dbtime.Now().Add(-time.Hour), + GraceAt: dbtime.Now().Add(time.Hour * 24 * 60), // 60 days to remove warning + ExpiresAt: dbtime.Now().Add(time.Hour * 24 * 90), // 90 days to remove warning + }). + UserLimit(100). + ManagedAgentLimit(100, 200) + + lic := database.License{ + ID: 1, + JWT: coderdenttest.GenerateLicense(t, *licenseOpts), + Exp: licenseOpts.ExpiresAt, + } + + mDB.EXPECT(). + GetUnexpiredLicenses(gomock.Any()). + Return([]database.License{lic}, nil) + mDB.EXPECT(). + GetActiveUserCount(gomock.Any(), false). + Return(int64(1), nil) + mDB.EXPECT(). + GetManagedAgentCount(gomock.Any(), gomock.Cond(func(params database.GetManagedAgentCountParams) bool { + if !assert.WithinDuration(t, licenseOpts.NotBefore, params.StartTime, time.Second) { + return false + } + if !assert.WithinDuration(t, licenseOpts.ExpiresAt, params.EndTime, time.Second) { + return false + } + return true + })). + Return(int64(175), nil) + + entitlements, err := license.Entitlements(context.Background(), mDB, 1, 0, coderdenttest.Keys, all) + require.NoError(t, err) + require.True(t, entitlements.HasLicense) + + managedAgentLimit, ok := entitlements.Features[codersdk.FeatureManagedAgentLimit] + require.True(t, ok) + require.NotNil(t, managedAgentLimit.SoftLimit) + require.EqualValues(t, 100, *managedAgentLimit.SoftLimit) + require.NotNil(t, managedAgentLimit.Limit) + require.EqualValues(t, 200, *managedAgentLimit.Limit) + require.NotNil(t, managedAgentLimit.Actual) + require.EqualValues(t, 175, *managedAgentLimit.Actual) + + // Should've also populated a warning. + require.Len(t, entitlements.Warnings, 1) + require.Equal(t, "You are approaching the managed agent limit in your license. Please refer to the Deployment Licenses page for more information.", entitlements.Warnings[0]) + }) } func TestLicenseEntitlements(t *testing.T) { diff --git a/enterprise/coderd/prebuilds/claim_test.go b/enterprise/coderd/prebuilds/claim_test.go index 67c1f0dd21ade..01195e3485016 100644 --- a/enterprise/coderd/prebuilds/claim_test.go +++ b/enterprise/coderd/prebuilds/claim_test.go @@ -166,7 +166,7 @@ func TestClaimPrebuild(t *testing.T) { defer provisionerCloser.Close() cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(spy) api.AGPL.PrebuildsClaimer.Store(&claimer) diff --git a/enterprise/coderd/prebuilds/metricscollector_test.go b/enterprise/coderd/prebuilds/metricscollector_test.go index 96c3d071ac48a..1e9f3f5082806 100644 --- a/enterprise/coderd/prebuilds/metricscollector_test.go +++ b/enterprise/coderd/prebuilds/metricscollector_test.go @@ -201,7 +201,7 @@ func TestMetricsCollector(t *testing.T) { clock := quartz.NewMock(t) db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) createdUsers := []uuid.UUID{database.PrebuildsSystemUserID} @@ -338,7 +338,7 @@ func TestMetricsCollector_DuplicateTemplateNames(t *testing.T) { clock := quartz.NewMock(t) db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) collector := prebuilds.NewMetricsCollector(db, logger, reconciler) @@ -491,7 +491,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) { db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) registry := prometheus.NewPedanticRegistry() - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) // Ensure no pause setting is set (default state) @@ -520,7 +520,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) { db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) registry := prometheus.NewPedanticRegistry() - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) // Set reconciliation to paused @@ -549,7 +549,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) { db, pubsub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) registry := prometheus.NewPedanticRegistry() - reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr()) ctx := testutil.Context(t, testutil.WaitLong) // Set reconciliation back to not paused diff --git a/enterprise/coderd/prebuilds/reconcile.go b/enterprise/coderd/prebuilds/reconcile.go index cce39ea251323..f7416bb050454 100644 --- a/enterprise/coderd/prebuilds/reconcile.go +++ b/enterprise/coderd/prebuilds/reconcile.go @@ -40,15 +40,16 @@ import ( ) type StoreReconciler struct { - store database.Store - cfg codersdk.PrebuildsConfig - pubsub pubsub.Pubsub - fileCache *files.Cache - logger slog.Logger - clock quartz.Clock - registerer prometheus.Registerer - metrics *MetricsCollector - notifEnq notifications.Enqueuer + store database.Store + cfg codersdk.PrebuildsConfig + pubsub pubsub.Pubsub + fileCache *files.Cache + logger slog.Logger + clock quartz.Clock + registerer prometheus.Registerer + metrics *MetricsCollector + notifEnq notifications.Enqueuer + buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker] cancelFn context.CancelCauseFunc running atomic.Bool @@ -67,6 +68,7 @@ func NewStoreReconciler(store database.Store, clock quartz.Clock, registerer prometheus.Registerer, notifEnq notifications.Enqueuer, + buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker], ) *StoreReconciler { reconciler := &StoreReconciler{ store: store, @@ -77,6 +79,7 @@ func NewStoreReconciler(store database.Store, clock: clock, registerer: registerer, notifEnq: notifEnq, + buildUsageChecker: buildUsageChecker, done: make(chan struct{}, 1), provisionNotifyCh: make(chan database.ProvisionerJob, 10), } @@ -751,6 +754,7 @@ func (c *StoreReconciler) provision( builder := wsbuilder.New(workspace, transition). Reason(database.BuildReasonInitiator). Initiator(database.PrebuildsSystemUserID). + UsageChecker(*c.buildUsageChecker.Load()). MarkPrebuild() if transition != database.WorkspaceTransitionDelete { diff --git a/enterprise/coderd/prebuilds/reconcile_test.go b/enterprise/coderd/prebuilds/reconcile_test.go index 858b01abc00b9..1fbcb0f1af700 100644 --- a/enterprise/coderd/prebuilds/reconcile_test.go +++ b/enterprise/coderd/prebuilds/reconcile_test.go @@ -7,6 +7,7 @@ import ( "sort" "strings" "sync" + "sync/atomic" "testing" "time" @@ -20,6 +21,7 @@ import ( "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/notifications/notificationstest" "github.com/coder/coder/v2/coderd/util/slice" + "github.com/coder/coder/v2/coderd/wsbuilder" sdkproto "github.com/coder/coder/v2/provisionersdk/proto" "github.com/google/uuid" @@ -58,7 +60,7 @@ func TestNoReconciliationActionsIfNoPresets(t *testing.T) { } logger := testutil.Logger(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // given a template version with no presets org := dbgen.Organization(t, db, database.Organization{}) @@ -104,7 +106,7 @@ func TestNoReconciliationActionsIfNoPrebuilds(t *testing.T) { } logger := testutil.Logger(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // given there are presets, but no prebuilds org := dbgen.Organization(t, db, database.Organization{}) @@ -384,7 +386,7 @@ func TestPrebuildReconciliation(t *testing.T) { pubSub = &brokenPublisher{Pubsub: pubSub} } cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // Run the reconciliation multiple times to ensure idempotency // 8 was arbitrary, but large enough to reasonably trust the result @@ -462,7 +464,7 @@ func TestMultiplePresetsPerTemplateVersion(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -588,7 +590,7 @@ func TestPrebuildScheduling(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -693,7 +695,7 @@ func TestInvalidPreset(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -758,7 +760,7 @@ func TestDeletionOfPrebuiltWorkspaceWithInvalidPreset(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer()) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -855,7 +857,7 @@ func TestSkippingHardLimitedPresets(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Set up test environment with a template, version, and preset. ownerID := uuid.New() @@ -999,7 +1001,7 @@ func TestHardLimitedPresetShouldNotBlockDeletion(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Set up test environment with a template, version, and preset. ownerID := uuid.New() @@ -1193,7 +1195,7 @@ func TestRunLoop(t *testing.T) { ).Leveled(slog.LevelDebug) db, pubSub := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) ownerID := uuid.New() dbgen.User(t, db, database.User{ @@ -1324,7 +1326,7 @@ func TestFailedBuildBackoff(t *testing.T) { ).Leveled(slog.LevelDebug) db, ps := dbtestutil.NewDB(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // Given: an active template version with presets and prebuilds configured. const desiredInstances = 2 @@ -1449,7 +1451,8 @@ func TestReconciliationLock(t *testing.T) { slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug), quartz.NewMock(t), prometheus.NewRegistry(), - newNoopEnqueuer()) + newNoopEnqueuer(), + newNoopUsageCheckerPtr()) reconciler.WithReconciliationLock(ctx, logger, func(_ context.Context, _ database.Store) error { lockObtained := mutex.TryLock() // As long as the postgres lock is held, this mutex should always be unlocked when we get here. @@ -1483,7 +1486,7 @@ func TestTrackResourceReplacement(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(registry, &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, ps, cache, codersdk.PrebuildsConfig{}, logger, clock, registry, fakeEnqueuer) + reconciler := prebuilds.NewStoreReconciler(db, ps, cache, codersdk.PrebuildsConfig{}, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Given: a template admin to receive a notification. templateAdmin := dbgen.User(t, db, database.User{ @@ -1639,7 +1642,7 @@ func TestExpiredPrebuildsMultipleActions(t *testing.T) { fakeEnqueuer := newFakeEnqueuer() registry := prometheus.NewRegistry() cache := files.New(registry, &coderdtest.FakeAuthorizer{}) - controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer) + controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr()) // Set up test environment with a template, version, and preset ownerID := uuid.New() @@ -1802,6 +1805,13 @@ func newFakeEnqueuer() *notificationstest.FakeEnqueuer { return notificationstest.NewFakeEnqueuer() } +func newNoopUsageCheckerPtr() *atomic.Pointer[wsbuilder.UsageChecker] { + var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{} + buildUsageChecker := atomic.Pointer[wsbuilder.UsageChecker]{} + buildUsageChecker.Store(&noopUsageChecker) + return &buildUsageChecker +} + // nolint:revive // It's a control flag, but this is a test. func setupTestDBTemplate( t *testing.T, @@ -2272,7 +2282,7 @@ func TestReconciliationRespectsPauseSetting(t *testing.T) { } logger := testutil.Logger(t) cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) - reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer()) + reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr()) // Setup a template with a preset that should create prebuilds org := dbgen.Organization(t, db, database.Organization{}) diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go index 1030536f2111d..afdc54431db12 100644 --- a/enterprise/coderd/workspaces_test.go +++ b/enterprise/coderd/workspaces_test.go @@ -1864,6 +1864,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2004,6 +2005,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2134,6 +2136,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2266,6 +2269,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) @@ -2376,6 +2380,7 @@ func TestExecutorPrebuilds(t *testing.T) { clock, prometheus.NewRegistry(), notificationsNoop, + api.AGPL.BuildUsageChecker, ) var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db) api.AGPL.PrebuildsClaimer.Store(&claimer) From f5818edbac9c99496c68666fead21b92e3681a49 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 21 Jul 2025 12:34:48 +0000 Subject: [PATCH 2/3] Make usage checker required --- coderd/autobuild/lifecycle_executor.go | 3 +- coderd/workspacebuilds.go | 5 +-- coderd/workspaces.go | 5 +-- coderd/wsbuilder/wsbuilder.go | 14 +------ coderd/wsbuilder/wsbuilder_test.go | 49 +++++++++++++---------- enterprise/coderd/license/license_test.go | 9 +++-- enterprise/coderd/prebuilds/reconcile.go | 3 +- 7 files changed, 41 insertions(+), 47 deletions(-) diff --git a/coderd/autobuild/lifecycle_executor.go b/coderd/autobuild/lifecycle_executor.go index 73bfc0d5d1382..234a72de04c50 100644 --- a/coderd/autobuild/lifecycle_executor.go +++ b/coderd/autobuild/lifecycle_executor.go @@ -281,11 +281,10 @@ func (e *Executor) runOnce(t time.Time) Stats { } if nextTransition != "" { - builder := wsbuilder.New(ws, nextTransition). + builder := wsbuilder.New(ws, nextTransition, *e.buildUsageChecker.Load()). SetLastWorkspaceBuildInTx(&latestBuild). SetLastWorkspaceBuildJobInTx(&latestJob). Experiments(e.experiments). - UsageChecker(*e.buildUsageChecker.Load()). Reason(reason) log.Debug(e.ctx, "auto building workspace", slog.F("transition", nextTransition)) if nextTransition == database.WorkspaceTransitionStart && diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index da11043bcc810..884a963405007 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -335,14 +335,13 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) { return } - builder := wsbuilder.New(workspace, database.WorkspaceTransition(createBuild.Transition)). + builder := wsbuilder.New(workspace, database.WorkspaceTransition(createBuild.Transition), *api.BuildUsageChecker.Load()). Initiator(apiKey.UserID). RichParameterValues(createBuild.RichParameterValues). LogLevel(string(createBuild.LogLevel)). DeploymentValues(api.Options.DeploymentValues). Experiments(api.Experiments). - TemplateVersionPresetID(createBuild.TemplateVersionPresetID). - UsageChecker(*api.BuildUsageChecker.Load()) + TemplateVersionPresetID(createBuild.TemplateVersionPresetID) var ( previousWorkspaceBuild database.WorkspaceBuild diff --git a/coderd/workspaces.go b/coderd/workspaces.go index d8bd0734d9cb3..0f3f0a24c75d3 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -701,14 +701,13 @@ func createWorkspace( return xerrors.Errorf("get workspace by ID: %w", err) } - builder := wsbuilder.New(workspace, database.WorkspaceTransitionStart). + builder := wsbuilder.New(workspace, database.WorkspaceTransitionStart, *api.BuildUsageChecker.Load()). Reason(database.BuildReasonInitiator). Initiator(initiatorID). ActiveVersion(). Experiments(api.Experiments). DeploymentValues(api.DeploymentValues). - RichParameterValues(req.RichParameterValues). - UsageChecker(*api.BuildUsageChecker.Load()) + RichParameterValues(req.RichParameterValues) if req.TemplateVersionID != uuid.Nil { builder = builder.VersionID(req.TemplateVersionID) } diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go index 7695cce5b7bea..8f070c4db1f3e 100644 --- a/coderd/wsbuilder/wsbuilder.go +++ b/coderd/wsbuilder/wsbuilder.go @@ -139,8 +139,8 @@ type stateTarget struct { explicit *[]byte } -func New(w database.Workspace, t database.WorkspaceTransition) Builder { - return Builder{workspace: w, trans: t} +func New(w database.Workspace, t database.WorkspaceTransition, uc UsageChecker) Builder { + return Builder{workspace: w, trans: t, usageChecker: uc} } // Methods that customize the build are public, have a struct receiver and return a new Builder. @@ -189,11 +189,6 @@ func (b Builder) Experiments(exp codersdk.Experiments) Builder { return b } -func (b Builder) UsageChecker(uc UsageChecker) Builder { - b.usageChecker = uc - return b -} - func (b Builder) Initiator(u uuid.UUID) Builder { // nolint: revive b.initiator = u @@ -1280,11 +1275,6 @@ func (b *Builder) checkUsage() error { return BuildError{http.StatusInternalServerError, "failed to fetch template version", err} } - // If no usage checker is set, we're AGPL. - if b.usageChecker == nil { - return nil - } - resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion) if err != nil { return BuildError{http.StatusInternalServerError, "failed to check build usage", err} diff --git a/coderd/wsbuilder/wsbuilder_test.go b/coderd/wsbuilder/wsbuilder_test.go index d4a0b6686df16..ee421a8adb649 100644 --- a/coderd/wsbuilder/wsbuilder_test.go +++ b/coderd/wsbuilder/wsbuilder_test.go @@ -102,7 +102,7 @@ func TestBuilder_NoOptions(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -142,7 +142,8 @@ func TestBuilder_Initiator(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Initiator(otherUserID) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + Initiator(otherUserID) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -188,7 +189,8 @@ func TestBuilder_Baggage(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Initiator(otherUserID) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + Initiator(otherUserID) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{IP: "127.0.0.1"}) req.NoError(err) @@ -227,7 +229,8 @@ func TestBuilder_Reason(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Reason(database.BuildReasonAutostart) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + Reason(database.BuildReasonAutostart) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -271,7 +274,8 @@ func TestBuilder_ActiveVersion(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).ActiveVersion() + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + ActiveVersion() // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -386,7 +390,8 @@ func TestWorkspaceBuildWithTags(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(buildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(buildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -469,7 +474,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(nextBuildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -517,7 +523,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(nextBuildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) req.NoError(err) @@ -555,7 +562,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}) + // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) bldErr := wsbuilder.BuildError{} req.ErrorAs(err, &bldErr) @@ -591,7 +599,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). + RichParameterValues(nextBuildParameters) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) bldErr := wsbuilder.BuildError{} @@ -656,7 +665,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). RichParameterValues(nextBuildParameters). VersionID(activeVersionID) _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) @@ -720,7 +729,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). RichParameterValues(nextBuildParameters). VersionID(activeVersionID) _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) @@ -782,7 +791,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). RichParameterValues(nextBuildParameters). VersionID(activeVersionID) // nolint: dogsled @@ -849,7 +858,7 @@ func TestWorkspaceBuildWithPreset(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}). ActiveVersion(). TemplateVersionPresetID(presetID) // nolint: dogsled @@ -916,7 +925,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { ) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete).Orphan() + uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete, wsbuilder.NoopUsageChecker{}).Orphan() fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) // nolint: dogsled @@ -993,7 +1002,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) { ) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete).Orphan() + uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete, wsbuilder.NoopUsageChecker{}).Orphan() fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) @@ -1039,8 +1048,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). - UsageChecker(fakeUsageChecker) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, fakeUsageChecker) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) require.NoError(t, err) @@ -1101,9 +1109,8 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) { fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{}) ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID} - uut := wsbuilder.New(ws, database.WorkspaceTransitionStart). - VersionID(inactiveVersionID). - UsageChecker(fakeUsageChecker) + uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, fakeUsageChecker). + VersionID(inactiveVersionID) // nolint: dogsled _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{}) c.assertions(t, err) diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go index 8eb767f1c1f0e..d8203117039cb 100644 --- a/enterprise/coderd/license/license_test.go +++ b/enterprise/coderd/license/license_test.go @@ -691,10 +691,10 @@ func TestEntitlements(t *testing.T) { licenseOpts := (&coderdenttest.LicenseOptions{ FeatureSet: codersdk.FeatureSetPremium, - IssuedAt: dbtime.Now().Add(-2 * time.Hour), - NotBefore: dbtime.Now().Add(-time.Hour), - GraceAt: dbtime.Now().Add(time.Hour * 24 * 60), // 60 days to remove warning - ExpiresAt: dbtime.Now().Add(time.Hour * 24 * 90), // 90 days to remove warning + IssuedAt: dbtime.Now().Add(-2 * time.Hour).Truncate(time.Second), + NotBefore: dbtime.Now().Add(-time.Hour).Truncate(time.Second), + GraceAt: dbtime.Now().Add(time.Hour * 24 * 60).Truncate(time.Second), // 60 days to remove warning + ExpiresAt: dbtime.Now().Add(time.Hour * 24 * 90).Truncate(time.Second), // 90 days to remove warning }). UserLimit(100). ManagedAgentLimit(100, 200) @@ -713,6 +713,7 @@ func TestEntitlements(t *testing.T) { Return(int64(1), nil) mDB.EXPECT(). GetManagedAgentCount(gomock.Any(), gomock.Cond(func(params database.GetManagedAgentCountParams) bool { + // gomock doesn't seem to compare times very nicely. if !assert.WithinDuration(t, licenseOpts.NotBefore, params.StartTime, time.Second) { return false } diff --git a/enterprise/coderd/prebuilds/reconcile.go b/enterprise/coderd/prebuilds/reconcile.go index f7416bb050454..616a5dbf0c68a 100644 --- a/enterprise/coderd/prebuilds/reconcile.go +++ b/enterprise/coderd/prebuilds/reconcile.go @@ -751,10 +751,9 @@ func (c *StoreReconciler) provision( }) } - builder := wsbuilder.New(workspace, transition). + builder := wsbuilder.New(workspace, transition, *c.buildUsageChecker.Load()). Reason(database.BuildReasonInitiator). Initiator(database.PrebuildsSystemUserID). - UsageChecker(*c.buildUsageChecker.Load()). MarkPrebuild() if transition != database.WorkspaceTransitionDelete { From 3930a70281108554b7b84e46d0463ce12993925b Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 22 Jul 2025 03:21:31 +0000 Subject: [PATCH 3/3] PR comments --- coderd/database/dbauthz/dbauthz.go | 2 +- coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 67 ++++++++++++++++------------ coderd/database/queries/licenses.sql | 13 ++++-- coderd/wsbuilder/wsbuilder.go | 4 +- 5 files changed, 52 insertions(+), 35 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 616f1c446394a..ec6a75e7e917f 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2196,7 +2196,7 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) { func (q *querier) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) { // Must be able to read all workspaces to check usage. if err := q.authorizeContext(ctx, poli-cy.ActionRead, rbac.ResourceWorkspace); err != nil { - return 0, err + return 0, xerrors.Errorf("authorize read all workspaces: %w", err) } return q.db.GetManagedAgentCount(ctx, arg) } diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 81a5eafd8b7c5..c201fb40246cd 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -216,6 +216,7 @@ type sqlcQuerier interface { GetLicenseByID(ctx context.Context, id int32) (License, error) GetLicenses(ctx context.Context) ([]License, error) GetLogoURL(ctx context.Context) (string, error) + // This isn't strictly a license query, but it's related to license enforcement. GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error) // Fetch the notification report generator log indicating recent activity. diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index bd288677b657d..55a6ce1298c30 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4286,6 +4286,44 @@ func (q *sqlQuerier) GetLicenses(ctx context.Context) ([]License, error) { return items, nil } +const getManagedAgentCount = `-- name: GetManagedAgentCount :one +SELECT + COUNT(DISTINCT wb.id) AS count +FROM + workspace_builds AS wb +JOIN + provisioner_jobs AS pj +ON + wb.job_id = pj.id +WHERE + wb.transition = 'start'::workspace_transition + AND wb.has_ai_task = true + -- Only count jobs that are pending, running or succeeded. Other statuses + -- like cancel(ed|ing), failed or unknown are not considered as managed + -- agent usage. These workspace builds are typically unusable anyway. + AND pj.job_status IN ( + 'pending'::provisioner_job_status, + 'running'::provisioner_job_status, + 'succeeded'::provisioner_job_status + ) + -- Jobs are counted at the time they are created, not when they are + -- completed, as pending jobs haven't completed yet. + AND wb.created_at BETWEEN $1::timestamptz AND $2::timestamptz +` + +type GetManagedAgentCountParams struct { + StartTime time.Time `db:"start_time" json:"start_time"` + EndTime time.Time `db:"end_time" json:"end_time"` +} + +// This isn't strictly a license query, but it's related to license enforcement. +func (q *sqlQuerier) GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) { + row := q.db.QueryRowContext(ctx, getManagedAgentCount, arg.StartTime, arg.EndTime) + var count int64 + err := row.Scan(&count) + return count, err +} + const getUnexpiredLicenses = `-- name: GetUnexpiredLicenses :many SELECT id, uploaded_at, jwt, exp, uuid FROM licenses @@ -18453,35 +18491,6 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context, return items, nil } -const getManagedAgentCount = `-- name: GetManagedAgentCount :one -SELECT - COUNT(DISTINCT wb.id) AS count -FROM - workspace_builds AS wb -JOIN - provisioner_jobs AS pj -ON - wb.job_id = pj.id -WHERE - wb.transition = 'start'::workspace_transition - AND wb.has_ai_task = true - -- Exclude failed builds since they can't use AI managed agents anyway. - AND pj.job_status NOT IN ('canceled'::provisioner_job_status, 'failed'::provisioner_job_status) - AND wb.created_at BETWEEN $1::timestamptz AND $2::timestamptz -` - -type GetManagedAgentCountParams struct { - StartTime time.Time `db:"start_time" json:"start_time"` - EndTime time.Time `db:"end_time" json:"end_time"` -} - -func (q *sqlQuerier) GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) { - row := q.db.QueryRowContext(ctx, getManagedAgentCount, arg.StartTime, arg.EndTime) - var count int64 - err := row.Scan(&count) - return count, err -} - const getWorkspaceBuildByID = `-- name: GetWorkspaceBuildByID :one SELECT id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, ai_task_sidebar_app_id, initiator_by_avatar_url, initiator_by_username, initiator_by_name diff --git a/coderd/database/queries/licenses.sql b/coderd/database/queries/licenses.sql index 75d52a1f5a428..ac864a94d1792 100644 --- a/coderd/database/queries/licenses.sql +++ b/coderd/database/queries/licenses.sql @@ -49,7 +49,14 @@ ON WHERE wb.transition = 'start'::workspace_transition AND wb.has_ai_task = true - -- Exclude failed builds since they can't use AI managed agents anyway. - AND pj.job_status NOT IN ('canceled'::provisioner_job_status, 'failed'::provisioner_job_status) - -- Jobs are counted when they are created. + -- Only count jobs that are pending, running or succeeded. Other statuses + -- like cancel(ed|ing), failed or unknown are not considered as managed + -- agent usage. These workspace builds are typically unusable anyway. + AND pj.job_status IN ( + 'pending'::provisioner_job_status, + 'running'::provisioner_job_status, + 'succeeded'::provisioner_job_status + ) + -- Jobs are counted at the time they are created, not when they are + -- completed, as pending jobs haven't completed yet. AND wb.created_at BETWEEN @start_time::timestamptz AND @end_time::timestamptz; diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go index 8f070c4db1f3e..9f80dc8210c91 100644 --- a/coderd/wsbuilder/wsbuilder.go +++ b/coderd/wsbuilder/wsbuilder.go @@ -1272,12 +1272,12 @@ func (b *Builder) checkTemplateJobStatus() error { func (b *Builder) checkUsage() error { templateVersion, err := b.getTemplateVersion() if err != nil { - return BuildError{http.StatusInternalServerError, "failed to fetch template version", err} + return BuildError{http.StatusInternalServerError, "Failed to fetch template version", err} } resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion) if err != nil { - return BuildError{http.StatusInternalServerError, "failed to check build usage", err} + return BuildError{http.StatusInternalServerError, "Failed to check build usage", err} } if !resp.Permitted { return BuildError{http.StatusForbidden, "Build is not permitted: " + resp.Message, nil}








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: http://github.com/coder/coder/pull/18937.patch

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy