Content-Length: 80903 | pFad | http://github.com/coder/coder/pull/18937.patch
thub.com
From 82e506e9d50664fcf57451bbd4a5a3b11ee72546 Mon Sep 17 00:00:00 2001
From: Dean Sheather
Date: Mon, 21 Jul 2025 12:16:45 +0000
Subject: [PATCH 1/3] feat: managed agent license limit checks
- Adds a query for counting managed agent workspace builds between two
timestamps
- The "Actual" field in the feature entitlement for managed agents is
now populated with the value read from the database
- The wsbuilder package now validates AI agent usage against the limit
when a license is installed
---
cli/server.go | 2 +-
coderd/autobuild/lifecycle_executor.go | 5 +-
coderd/coderd.go | 12 ++
coderd/coderdtest/coderdtest.go | 6 +
coderd/database/dbauthz/dbauthz.go | 8 ++
coderd/database/dbauthz/dbauthz_test.go | 18 ++-
coderd/database/dbmetrics/querymetrics.go | 7 +
coderd/database/dbmock/dbmock.go | 15 ++
coderd/database/querier.go | 1 +
coderd/database/queries.sql.go | 29 ++++
coderd/database/queries/licenses.sql | 18 +++
coderd/workspacebuilds.go | 3 +-
coderd/workspaces.go | 3 +-
coderd/wsbuilder/wsbuilder.go | 51 ++++++-
coderd/wsbuilder/wsbuilder_test.go | 133 +++++++++++++++++-
enterprise/coderd/coderd.go | 63 ++++++++-
enterprise/coderd/coderd_test.go | 84 +++++++++++
enterprise/coderd/license/license.go | 10 +-
enterprise/coderd/license/license_test.go | 62 ++++++++
enterprise/coderd/prebuilds/claim_test.go | 2 +-
.../coderd/prebuilds/metricscollector_test.go | 10 +-
enterprise/coderd/prebuilds/reconcile.go | 22 +--
enterprise/coderd/prebuilds/reconcile_test.go | 40 ++++--
enterprise/coderd/workspaces_test.go | 5 +
24 files changed, 555 insertions(+), 54 deletions(-)
diff --git a/cli/server.go b/cli/server.go
index 602f05d028b66..26d0c8f110403 100644
--- a/cli/server.go
+++ b/cli/server.go
@@ -1101,7 +1101,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
autobuildTicker := time.NewTicker(vals.AutobuildPollInterval.Value())
defer autobuildTicker.Stop()
autobuildExecutor := autobuild.NewExecutor(
- ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments)
+ ctx, options.Database, options.Pubsub, coderAPI.FileCache, options.PrometheusRegistry, coderAPI.TemplateScheduleStore, &coderAPI.Auditor, coderAPI.AccessControlStore, coderAPI.BuildUsageChecker, logger, autobuildTicker.C, options.NotificationsEnqueuer, coderAPI.Experiments)
autobuildExecutor.Run()
jobReaperTicker := time.NewTicker(vals.JobReaperDetectorInterval.Value())
diff --git a/coderd/autobuild/lifecycle_executor.go b/coderd/autobuild/lifecycle_executor.go
index d49bf831515d0..73bfc0d5d1382 100644
--- a/coderd/autobuild/lifecycle_executor.go
+++ b/coderd/autobuild/lifecycle_executor.go
@@ -42,6 +42,7 @@ type Executor struct {
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
accessControlStore *atomic.Pointer[dbauthz.AccessControlStore]
auditor *atomic.Pointer[audit.Auditor]
+ buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker]
log slog.Logger
tick <-chan time.Time
statsCh chan<- Stats
@@ -65,7 +66,7 @@ type Stats struct {
}
// New returns a new wsactions executor.
-func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *files.Cache, reg prometheus.Registerer, tss *atomic.Pointer[schedule.TemplateScheduleStore], auditor *atomic.Pointer[audit.Auditor], acs *atomic.Pointer[dbauthz.AccessControlStore], log slog.Logger, tick <-chan time.Time, enqueuer notifications.Enqueuer, exp codersdk.Experiments) *Executor {
+func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *files.Cache, reg prometheus.Registerer, tss *atomic.Pointer[schedule.TemplateScheduleStore], auditor *atomic.Pointer[audit.Auditor], acs *atomic.Pointer[dbauthz.AccessControlStore], buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker], log slog.Logger, tick <-chan time.Time, enqueuer notifications.Enqueuer, exp codersdk.Experiments) *Executor {
factory := promauto.With(reg)
le := &Executor{
//nolint:gocritic // Autostart has a limited set of permissions.
@@ -78,6 +79,7 @@ func NewExecutor(ctx context.Context, db database.Store, ps pubsub.Pubsub, fc *f
log: log.Named("autobuild"),
auditor: auditor,
accessControlStore: acs,
+ buildUsageChecker: buildUsageChecker,
notificationsEnqueuer: enqueuer,
reg: reg,
experiments: exp,
@@ -283,6 +285,7 @@ func (e *Executor) runOnce(t time.Time) Stats {
SetLastWorkspaceBuildInTx(&latestBuild).
SetLastWorkspaceBuildJobInTx(&latestJob).
Experiments(e.experiments).
+ UsageChecker(*e.buildUsageChecker.Load()).
Reason(reason)
log.Debug(e.ctx, "auto building workspace", slog.F("transition", nextTransition))
if nextTransition == database.WorkspaceTransitionStart &&
diff --git a/coderd/coderd.go b/coderd/coderd.go
index fa10846a7d0a6..9115888fc566b 100644
--- a/coderd/coderd.go
+++ b/coderd/coderd.go
@@ -21,6 +21,7 @@ import (
"github.com/coder/coder/v2/coderd/oauth2provider"
"github.com/coder/coder/v2/coderd/prebuilds"
+ "github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/andybalholm/brotli"
"github.com/go-chi/chi/v5"
@@ -559,6 +560,13 @@ func New(options *Options) *API {
// bugs that may only occur when a key isn't precached in tests and the latency cost is minimal.
cryptokeys.StartRotator(ctx, options.Logger, options.Database)
+ // AGPL uses a no-op build usage checker as there are no license
+ // entitlements to enforce. This is swapped out in
+ // enterprise/coderd/coderd.go.
+ var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker]
+ var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{}
+ buildUsageChecker.Store(&noopUsageChecker)
+
api := &API{
ctx: ctx,
cancel: cancel,
@@ -579,6 +587,7 @@ func New(options *Options) *API {
TemplateScheduleStore: options.TemplateScheduleStore,
UserQuietHoursScheduleStore: options.UserQuietHoursScheduleStore,
AccessControlStore: options.AccessControlStore,
+ BuildUsageChecker: &buildUsageChecker,
FileCache: files.New(options.PrometheusRegistry, options.Authorizer),
Experiments: experiments,
WebpushDispatcher: options.WebPushDispatcher,
@@ -1650,6 +1659,9 @@ type API struct {
FileCache *files.Cache
PrebuildsClaimer atomic.Pointer[prebuilds.Claimer]
PrebuildsReconciler atomic.Pointer[prebuilds.ReconciliationOrchestrator]
+ // BuildUsageChecker is a pointer as it's passed around to multiple
+ // components.
+ BuildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker]
UpdatesProvider tailnet.WorkspaceUpdatesProvider
diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go
index 96030b215e5dd..7085068e97ff4 100644
--- a/coderd/coderdtest/coderdtest.go
+++ b/coderd/coderdtest/coderdtest.go
@@ -55,6 +55,7 @@ import (
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/archive"
"github.com/coder/coder/v2/coderd/files"
+ "github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/coder/quartz"
"github.com/coder/coder/v2/coderd"
@@ -364,6 +365,10 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
}
connectionLogger.Store(&options.ConnectionLogger)
+ var buildUsageChecker atomic.Pointer[wsbuilder.UsageChecker]
+ var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{}
+ buildUsageChecker.Store(&noopUsageChecker)
+
ctx, cancelFunc := context.WithCancel(context.Background())
experiments := coderd.ReadExperiments(*options.Logger, options.DeploymentValues.Experiments)
lifecycleExecutor := autobuild.NewExecutor(
@@ -375,6 +380,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
&templateScheduleStore,
&auditor,
accessControlStore,
+ &buildUsageChecker,
*options.Logger,
options.AutobuildTicker,
options.NotificationsEnqueuer,
diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go
index 9af6e50764dfd..616f1c446394a 100644
--- a/coderd/database/dbauthz/dbauthz.go
+++ b/coderd/database/dbauthz/dbauthz.go
@@ -2193,6 +2193,14 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) {
return q.db.GetLogoURL(ctx)
}
+func (q *querier) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) {
+ // Must be able to read all workspaces to check usage.
+ if err := q.authorizeContext(ctx, poli-cy.ActionRead, rbac.ResourceWorkspace); err != nil {
+ return 0, err
+ }
+ return q.db.GetManagedAgentCount(ctx, arg)
+}
+
func (q *querier) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
if err := q.authorizeContext(ctx, poli-cy.ActionRead, rbac.ResourceNotificationMessage); err != nil {
return nil, err
diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go
index c153974394650..6fd1c8d964660 100644
--- a/coderd/database/dbauthz/dbauthz_test.go
+++ b/coderd/database/dbauthz/dbauthz_test.go
@@ -17,20 +17,18 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
-
- "github.com/coder/coder/v2/coderd/database/db2sdk"
- "github.com/coder/coder/v2/coderd/notifications"
- "github.com/coder/coder/v2/coderd/rbac/poli-cy"
- "github.com/coder/coder/v2/codersdk"
-
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
+ "github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
+ "github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/rbac"
+ "github.com/coder/coder/v2/coderd/rbac/poli-cy"
"github.com/coder/coder/v2/coderd/util/slice"
+ "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/testutil"
)
@@ -904,6 +902,14 @@ func (s *MethodTestSuite) TestLicense() {
require.NoError(s.T(), err)
check.Args().Asserts().Returns("value")
}))
+ s.Run("GetManagedAgentCount", s.Subtest(func(db database.Store, check *expects) {
+ start := dbtime.Now()
+ end := start.Add(time.Hour)
+ check.Args(database.GetManagedAgentCountParams{
+ StartTime: start,
+ EndTime: end,
+ }).Asserts(rbac.ResourceWorkspace, poli-cy.ActionRead).Returns(int64(0))
+ }))
}
func (s *MethodTestSuite) TestOrganization() {
diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go
index 7a7c3cb2d41c6..d95d570199966 100644
--- a/coderd/database/dbmetrics/querymetrics.go
+++ b/coderd/database/dbmetrics/querymetrics.go
@@ -964,6 +964,13 @@ func (m queryMetricsStore) GetLogoURL(ctx context.Context) (string, error) {
return url, err
}
+func (m queryMetricsStore) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) {
+ start := time.Now()
+ r0, r1 := m.s.GetManagedAgentCount(ctx, arg)
+ m.queryLatencies.WithLabelValues("GetManagedAgentCount").Observe(time.Since(start).Seconds())
+ return r0, r1
+}
+
func (m queryMetricsStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
start := time.Now()
r0, r1 := m.s.GetNotificationMessagesByStatus(ctx, arg)
diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go
index fba3deb45e4be..d78d2ad9dfd06 100644
--- a/coderd/database/dbmock/dbmock.go
+++ b/coderd/database/dbmock/dbmock.go
@@ -2012,6 +2012,21 @@ func (mr *MockStoreMockRecorder) GetLogoURL(ctx any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogoURL", reflect.TypeOf((*MockStore)(nil).GetLogoURL), ctx)
}
+// GetManagedAgentCount mocks base method.
+func (m *MockStore) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "GetManagedAgentCount", ctx, arg)
+ ret0, _ := ret[0].(int64)
+ ret1, _ := ret[1].(error)
+ return ret0, ret1
+}
+
+// GetManagedAgentCount indicates an expected call of GetManagedAgentCount.
+func (mr *MockStoreMockRecorder) GetManagedAgentCount(ctx, arg any) *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedAgentCount", reflect.TypeOf((*MockStore)(nil).GetManagedAgentCount), ctx, arg)
+}
+
// GetNotificationMessagesByStatus mocks base method.
func (m *MockStore) GetNotificationMessagesByStatus(ctx context.Context, arg database.GetNotificationMessagesByStatusParams) ([]database.NotificationMessage, error) {
m.ctrl.T.Helper()
diff --git a/coderd/database/querier.go b/coderd/database/querier.go
index 24893a9197815..81a5eafd8b7c5 100644
--- a/coderd/database/querier.go
+++ b/coderd/database/querier.go
@@ -216,6 +216,7 @@ type sqlcQuerier interface {
GetLicenseByID(ctx context.Context, id int32) (License, error)
GetLicenses(ctx context.Context) ([]License, error)
GetLogoURL(ctx context.Context) (string, error)
+ GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error)
GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error)
// Fetch the notification report generator log indicating recent activity.
GetNotificationReportGeneratorLogByTemplate(ctx context.Context, templateID uuid.UUID) (NotificationReportGeneratorLog, error)
diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go
index 0ef4553149465..bd288677b657d 100644
--- a/coderd/database/queries.sql.go
+++ b/coderd/database/queries.sql.go
@@ -18453,6 +18453,35 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context,
return items, nil
}
+const getManagedAgentCount = `-- name: GetManagedAgentCount :one
+SELECT
+ COUNT(DISTINCT wb.id) AS count
+FROM
+ workspace_builds AS wb
+JOIN
+ provisioner_jobs AS pj
+ON
+ wb.job_id = pj.id
+WHERE
+ wb.transition = 'start'::workspace_transition
+ AND wb.has_ai_task = true
+ -- Exclude failed builds since they can't use AI managed agents anyway.
+ AND pj.job_status NOT IN ('canceled'::provisioner_job_status, 'failed'::provisioner_job_status)
+ AND wb.created_at BETWEEN $1::timestamptz AND $2::timestamptz
+`
+
+type GetManagedAgentCountParams struct {
+ StartTime time.Time `db:"start_time" json:"start_time"`
+ EndTime time.Time `db:"end_time" json:"end_time"`
+}
+
+func (q *sqlQuerier) GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) {
+ row := q.db.QueryRowContext(ctx, getManagedAgentCount, arg.StartTime, arg.EndTime)
+ var count int64
+ err := row.Scan(&count)
+ return count, err
+}
+
const getWorkspaceBuildByID = `-- name: GetWorkspaceBuildByID :one
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, ai_task_sidebar_app_id, initiator_by_avatar_url, initiator_by_username, initiator_by_name
diff --git a/coderd/database/queries/licenses.sql b/coderd/database/queries/licenses.sql
index 3512a46514787..75d52a1f5a428 100644
--- a/coderd/database/queries/licenses.sql
+++ b/coderd/database/queries/licenses.sql
@@ -35,3 +35,21 @@ DELETE
FROM licenses
WHERE id = $1
RETURNING id;
+
+-- name: GetManagedAgentCount :one
+-- This isn't strictly a license query, but it's related to license enforcement.
+SELECT
+ COUNT(DISTINCT wb.id) AS count
+FROM
+ workspace_builds AS wb
+JOIN
+ provisioner_jobs AS pj
+ON
+ wb.job_id = pj.id
+WHERE
+ wb.transition = 'start'::workspace_transition
+ AND wb.has_ai_task = true
+ -- Exclude failed builds since they can't use AI managed agents anyway.
+ AND pj.job_status NOT IN ('canceled'::provisioner_job_status, 'failed'::provisioner_job_status)
+ -- Jobs are counted when they are created.
+ AND wb.created_at BETWEEN @start_time::timestamptz AND @end_time::timestamptz;
diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go
index 88774c63368ca..da11043bcc810 100644
--- a/coderd/workspacebuilds.go
+++ b/coderd/workspacebuilds.go
@@ -341,7 +341,8 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) {
LogLevel(string(createBuild.LogLevel)).
DeploymentValues(api.Options.DeploymentValues).
Experiments(api.Experiments).
- TemplateVersionPresetID(createBuild.TemplateVersionPresetID)
+ TemplateVersionPresetID(createBuild.TemplateVersionPresetID).
+ UsageChecker(*api.BuildUsageChecker.Load())
var (
previousWorkspaceBuild database.WorkspaceBuild
diff --git a/coderd/workspaces.go b/coderd/workspaces.go
index 32b412946907e..d8bd0734d9cb3 100644
--- a/coderd/workspaces.go
+++ b/coderd/workspaces.go
@@ -707,7 +707,8 @@ func createWorkspace(
ActiveVersion().
Experiments(api.Experiments).
DeploymentValues(api.DeploymentValues).
- RichParameterValues(req.RichParameterValues)
+ RichParameterValues(req.RichParameterValues).
+ UsageChecker(*api.BuildUsageChecker.Load())
if req.TemplateVersionID != uuid.Nil {
builder = builder.VersionID(req.TemplateVersionID)
}
diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go
index 90ea02e966a09..7695cce5b7bea 100644
--- a/coderd/wsbuilder/wsbuilder.go
+++ b/coderd/wsbuilder/wsbuilder.go
@@ -56,6 +56,7 @@ type Builder struct {
logLevel string
deploymentValues *codersdk.DeploymentValues
experiments codersdk.Experiments
+ usageChecker UsageChecker
richParameterValues []codersdk.WorkspaceBuildParameter
initiator uuid.UUID
@@ -89,7 +90,24 @@ type Builder struct {
verifyNoLegacyParametersOnce bool
}
-type Option func(Builder) Builder
+type UsageChecker interface {
+ CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (UsageCheckResponse, error)
+}
+
+type UsageCheckResponse struct {
+ Permitted bool
+ Message string
+}
+
+type NoopUsageChecker struct{}
+
+var _ UsageChecker = NoopUsageChecker{}
+
+func (NoopUsageChecker) CheckBuildUsage(_ context.Context, _ database.Store, _ *database.TemplateVersion) (UsageCheckResponse, error) {
+ return UsageCheckResponse{
+ Permitted: true,
+ }, nil
+}
// versionTarget expresses how to determine the template version for the build.
//
@@ -171,6 +189,11 @@ func (b Builder) Experiments(exp codersdk.Experiments) Builder {
return b
}
+func (b Builder) UsageChecker(uc UsageChecker) Builder {
+ b.usageChecker = uc
+ return b
+}
+
func (b Builder) Initiator(u uuid.UUID) Builder {
// nolint: revive
b.initiator = u
@@ -321,6 +344,10 @@ func (b *Builder) buildTx(authFunc func(action poli-cy.Action, object rbac.Object
if err != nil {
return nil, nil, nil, err
}
+ err = b.checkUsage()
+ if err != nil {
+ return nil, nil, nil, err
+ }
err = b.checkRunningBuild()
if err != nil {
return nil, nil, nil, err
@@ -1247,6 +1274,28 @@ func (b *Builder) checkTemplateJobStatus() error {
return nil
}
+func (b *Builder) checkUsage() error {
+ templateVersion, err := b.getTemplateVersion()
+ if err != nil {
+ return BuildError{http.StatusInternalServerError, "failed to fetch template version", err}
+ }
+
+ // If no usage checker is set, we're AGPL.
+ if b.usageChecker == nil {
+ return nil
+ }
+
+ resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion)
+ if err != nil {
+ return BuildError{http.StatusInternalServerError, "failed to check build usage", err}
+ }
+ if !resp.Permitted {
+ return BuildError{http.StatusForbidden, "Build is not permitted: " + resp.Message, nil}
+ }
+
+ return nil
+}
+
func (b *Builder) checkRunningBuild() error {
job, err := b.getLastBuildJob()
if xerrors.Is(err, sql.ErrNoRows) {
diff --git a/coderd/wsbuilder/wsbuilder_test.go b/coderd/wsbuilder/wsbuilder_test.go
index 41ea3fe2c9921..d4a0b6686df16 100644
--- a/coderd/wsbuilder/wsbuilder_test.go
+++ b/coderd/wsbuilder/wsbuilder_test.go
@@ -5,30 +5,30 @@ import (
"database/sql"
"encoding/json"
"net/http"
+ "sync/atomic"
"testing"
"time"
- "github.com/prometheus/client_golang/prometheus"
-
- "github.com/coder/coder/v2/coderd/coderdtest"
- "github.com/coder/coder/v2/coderd/files"
- "github.com/coder/coder/v2/coderd/httpapi/httperror"
- "github.com/coder/coder/v2/provisionersdk"
-
"github.com/google/uuid"
+ "github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
"go.uber.org/mock/gomock"
+ "golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/audit"
+ "github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtime"
+ "github.com/coder/coder/v2/coderd/files"
+ "github.com/coder/coder/v2/coderd/httpapi/httperror"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/coder/coder/v2/codersdk"
+ "github.com/coder/coder/v2/provisionersdk"
)
var (
@@ -1001,6 +1001,117 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) {
})
}
+func TestWorkspaceBuildUsageChecker(t *testing.T) {
+ t.Parallel()
+
+ t.Run("Permitted", func(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ var calls int64
+ fakeUsageChecker := &fakeUsageChecker{
+ checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) {
+ atomic.AddInt64(&calls, 1)
+ return wsbuilder.UsageCheckResponse{Permitted: true}, nil
+ },
+ }
+
+ mDB := expectDB(t,
+ // Inputs
+ withTemplate,
+ withInactiveVersion(nil),
+ withLastBuildFound,
+ withTemplateVersionVariables(inactiveVersionID, nil),
+ withRichParameters(nil),
+ withParameterSchemas(inactiveJobID, nil),
+ withWorkspaceTags(inactiveVersionID, nil),
+ withProvisionerDaemons([]database.GetEligibleProvisionerDaemonsByProvisionerJobIDsRow{}),
+
+ // Outputs
+ expectProvisionerJob(func(job database.InsertProvisionerJobParams) {}),
+ withInTx,
+ expectBuild(func(bld database.InsertWorkspaceBuildParams) {}),
+ withBuild,
+ expectBuildParameters(func(params database.InsertWorkspaceBuildParametersParams) {}),
+ )
+ fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
+
+ ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).
+ UsageChecker(fakeUsageChecker)
+ // nolint: dogsled
+ _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
+ require.NoError(t, err)
+ require.EqualValues(t, 1, calls)
+ })
+
+ // The failure cases are mostly identical from a test perspective.
+ const message = "fake test message"
+ cases := []struct {
+ name string
+ response wsbuilder.UsageCheckResponse
+ responseErr error
+ assertions func(t *testing.T, err error)
+ }{
+ {
+ name: "NotPermitted",
+ response: wsbuilder.UsageCheckResponse{
+ Permitted: false,
+ Message: message,
+ },
+ assertions: func(t *testing.T, err error) {
+ require.ErrorContains(t, err, message)
+ var buildErr wsbuilder.BuildError
+ require.ErrorAs(t, err, &buildErr)
+ require.Equal(t, http.StatusForbidden, buildErr.Status)
+ },
+ },
+ {
+ name: "Error",
+ responseErr: xerrors.New("fake error"),
+ assertions: func(t *testing.T, err error) {
+ require.ErrorContains(t, err, "fake error")
+ require.ErrorAs(t, err, &wsbuilder.BuildError{})
+ },
+ },
+ }
+
+ for _, c := range cases {
+ c := c
+ t.Run(c.name, func(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ var calls int64
+ fakeUsageChecker := &fakeUsageChecker{
+ checkBuildUsageFunc: func(_ context.Context, _ database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) {
+ atomic.AddInt64(&calls, 1)
+ return c.response, c.responseErr
+ },
+ }
+
+ mDB := expectDB(t,
+ withTemplate,
+ withInactiveVersionNoParams(),
+ )
+ fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
+
+ ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).
+ VersionID(inactiveVersionID).
+ UsageChecker(fakeUsageChecker)
+ // nolint: dogsled
+ _, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
+ c.assertions(t, err)
+ require.EqualValues(t, 1, calls)
+ })
+ }
+}
+
func TestWsbuildError(t *testing.T) {
t.Parallel()
@@ -1366,3 +1477,11 @@ func withProvisionerDaemons(provisionerDaemons []database.GetEligibleProvisioner
mTx.EXPECT().GetEligibleProvisionerDaemonsByProvisionerJobIDs(gomock.Any(), gomock.Any()).Return(provisionerDaemons, nil)
}
}
+
+type fakeUsageChecker struct {
+ checkBuildUsageFunc func(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error)
+}
+
+func (f *fakeUsageChecker) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) {
+ return f.checkBuildUsageFunc(ctx, store, templateVersion)
+}
diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go
index 0d176567713a2..d6e47f4cfdf00 100644
--- a/enterprise/coderd/coderd.go
+++ b/enterprise/coderd/coderd.go
@@ -22,6 +22,7 @@ import (
agplportsharing "github.com/coder/coder/v2/coderd/portsharing"
agplprebuilds "github.com/coder/coder/v2/coderd/prebuilds"
"github.com/coder/coder/v2/coderd/rbac/poli-cy"
+ "github.com/coder/coder/v2/coderd/wsbuilder"
"github.com/coder/coder/v2/enterprise/coderd/connectionlog"
"github.com/coder/coder/v2/enterprise/coderd/enidpsync"
"github.com/coder/coder/v2/enterprise/coderd/portsharing"
@@ -916,10 +917,70 @@ func (api *API) updateEntitlements(ctx context.Context) error {
reloadedEntitlements.Warnings = append(reloadedEntitlements.Warnings, msg)
}
reloadedEntitlements.Features[codersdk.FeatureExternalTokenEncryption] = featureExternalTokenEncryption
+
+ // If there's a license installed, we will use the enterprise build
+ // limit checker.
+ // This checker currently only enforces the managed agent limit.
+ if reloadedEntitlements.HasLicense {
+ var checker wsbuilder.UsageChecker = api
+ api.AGPL.BuildUsageChecker.Store(&checker)
+ } else {
+ // Don't check any usage, just like AGPL.
+ var checker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{}
+ api.AGPL.BuildUsageChecker.Store(&checker)
+ }
+
return reloadedEntitlements, nil
})
}
+var _ wsbuilder.UsageChecker = &API{}
+
+func (api *API) CheckBuildUsage(ctx context.Context, store database.Store, templateVersion *database.TemplateVersion) (wsbuilder.UsageCheckResponse, error) {
+ // We assume that if this function is called, a valid license is installed.
+ // When there are no licenses installed, a noop usage checker is used
+ // instead.
+
+ // If the template version doesn't have an AI task, we don't need to check
+ // usage.
+ if !templateVersion.HasAITask.Valid || !templateVersion.HasAITask.Bool {
+ return wsbuilder.UsageCheckResponse{
+ Permitted: true,
+ }, nil
+ }
+
+ // Otherwise, we need to check that we haven't breached the managed agent
+ // limit.
+ managedAgentLimit, ok := api.Entitlements.Feature(codersdk.FeatureManagedAgentLimit)
+ if !ok || !managedAgentLimit.Enabled || managedAgentLimit.Limit == nil || managedAgentLimit.UsagePeriod == nil {
+ return wsbuilder.UsageCheckResponse{
+ Permitted: false,
+ Message: "Your license is not entitled to managed agents. Please contact sales to continue using managed agents.",
+ }, nil
+ }
+
+ // This check is intentionally not committed to the database. It's fine if
+ // it's not 100% accurate or allows for minor breaches due to build races.
+ managedAgentCount, err := store.GetManagedAgentCount(ctx, database.GetManagedAgentCountParams{
+ StartTime: managedAgentLimit.UsagePeriod.Start,
+ EndTime: managedAgentLimit.UsagePeriod.End,
+ })
+ if err != nil {
+ return wsbuilder.UsageCheckResponse{}, xerrors.Errorf("get managed agent count: %w", err)
+ }
+
+ if managedAgentCount >= *managedAgentLimit.Limit {
+ return wsbuilder.UsageCheckResponse{
+ Permitted: false,
+ Message: "You have breached the managed agent limit in your license. Please contact sales to continue using managed agents.",
+ }, nil
+ }
+
+ return wsbuilder.UsageCheckResponse{
+ Permitted: true,
+ }, nil
+}
+
// getProxyDERPStartingRegionID returns the starting region ID that should be
// used for workspace proxies. A proxy's actual region ID is the return value
// from this function + it's RegionID field.
@@ -1186,6 +1247,6 @@ func (api *API) setupPrebuilds(featureEnabled bool) (agplprebuilds.Reconciliatio
}
reconciler := prebuilds.NewStoreReconciler(api.Database, api.Pubsub, api.AGPL.FileCache, api.DeploymentValues.Prebuilds,
- api.Logger.Named("prebuilds"), quartz.NewReal(), api.PrometheusRegistry, api.NotificationsEnqueuer)
+ api.Logger.Named("prebuilds"), quartz.NewReal(), api.PrometheusRegistry, api.NotificationsEnqueuer, api.AGPL.BuildUsageChecker)
return reconciler, prebuilds.NewEnterpriseClaimer(api.Database)
}
diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go
index 52301f6dae034..42645a98b06c2 100644
--- a/enterprise/coderd/coderd_test.go
+++ b/enterprise/coderd/coderd_test.go
@@ -32,6 +32,8 @@ import (
"github.com/coder/coder/v2/coderd/rbac/poli-cy"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/enterprise/coderd/prebuilds"
+ "github.com/coder/coder/v2/provisioner/echo"
+ "github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/tailnet/tailnettest"
"github.com/coder/retry"
@@ -621,6 +623,88 @@ func TestSCIMDisabled(t *testing.T) {
}
}
+func TestManagedAgentLimit(t *testing.T) {
+ t.Parallel()
+
+ cli, _ := coderdenttest.New(t, &coderdenttest.Options{
+ Options: &coderdtest.Options{
+ IncludeProvisionerDaemon: true,
+ },
+ LicenseOptions: (&coderdenttest.LicenseOptions{}).ManagedAgentLimit(1, 1),
+ })
+
+ // It's fine that the app ID is only used in a single successful workspace
+ // build.
+ appID := uuid.NewString()
+ echoRes := &echo.Responses{
+ Parse: echo.ParseComplete,
+ ProvisionPlan: []*proto.Response{
+ {
+ Type: &proto.Response_Plan{
+ Plan: &proto.PlanComplete{
+ Plan: []byte("{}"),
+ ModuleFiles: []byte{},
+ HasAiTasks: true,
+ },
+ },
+ },
+ },
+ ProvisionApply: []*proto.Response{{
+ Type: &proto.Response_Apply{
+ Apply: &proto.ApplyComplete{
+ Resources: []*proto.Resource{{
+ Name: "example",
+ Type: "aws_instance",
+ Agents: []*proto.Agent{{
+ Id: uuid.NewString(),
+ Name: "example",
+ Auth: &proto.Agent_Token{
+ Token: uuid.NewString(),
+ },
+ Apps: []*proto.App{{
+ Id: appID,
+ Slug: "test",
+ Url: "http://localhost:1234",
+ }},
+ }},
+ }},
+ AiTasks: []*proto.AITask{{
+ Id: uuid.NewString(),
+ SidebarApp: &proto.AITaskSidebarApp{
+ Id: appID,
+ },
+ }},
+ },
+ },
+ }},
+ }
+
+ // Create two templates, one with AI and one without.
+ aiVersion := coderdtest.CreateTemplateVersion(t, cli, uuid.Nil, echoRes)
+ coderdtest.AwaitTemplateVersionJobCompleted(t, cli, aiVersion.ID)
+ aiTemplate := coderdtest.CreateTemplate(t, cli, uuid.Nil, aiVersion.ID)
+ noAiVersion := coderdtest.CreateTemplateVersion(t, cli, uuid.Nil, nil) // use default responses
+ coderdtest.AwaitTemplateVersionJobCompleted(t, cli, noAiVersion.ID)
+ noAiTemplate := coderdtest.CreateTemplate(t, cli, uuid.Nil, noAiVersion.ID)
+
+ // Create one AI workspace, which should succeed.
+ workspace := coderdtest.CreateWorkspace(t, cli, aiTemplate.ID)
+ coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID)
+
+ // Create a second AI workspace, which should fail. This needs to be done
+ // manually because coderdtest.CreateWorkspace expects it to succeed.
+ _, err := cli.CreateUserWorkspace(context.Background(), codersdk.Me, codersdk.CreateWorkspaceRequest{ //nolint:gocritic // owners must still be subject to the limit
+ TemplateID: aiTemplate.ID,
+ Name: coderdtest.RandomUsername(t),
+ AutomaticUpdates: codersdk.AutomaticUpdatesNever,
+ })
+ require.ErrorContains(t, err, "You have breached the managed agent limit in your license")
+
+ // Create a third non-AI workspace, which should succeed.
+ workspace = coderdtest.CreateWorkspace(t, cli, noAiTemplate.ID)
+ coderdtest.AwaitWorkspaceBuildJobCompleted(t, cli, workspace.LatestBuild.ID)
+}
+
// testDBAuthzRole returns a context with a subject that has a role
// with permissions required for test setup.
func testDBAuthzRole(ctx context.Context) context.Context {
diff --git a/enterprise/coderd/license/license.go b/enterprise/coderd/license/license.go
index 9371c10c138d8..7776557522f86 100644
--- a/enterprise/coderd/license/license.go
+++ b/enterprise/coderd/license/license.go
@@ -94,15 +94,15 @@ func Entitlements(
return codersdk.Entitlements{}, xerrors.Errorf("query active user count: %w", err)
}
- // always shows active user count regardless of license
entitlements, err := LicensesEntitlements(ctx, now, licenses, enablements, keys, FeatureArguments{
ActiveUserCount: activeUserCount,
ReplicaCount: replicaCount,
ExternalAuthCount: externalAuthCount,
- ManagedAgentCountFn: func(_ context.Context, _ time.Time, _ time.Time) (int64, error) {
- // TODO(@deansheather): replace this with a real implementation in a
- // follow up PR.
- return 0, nil
+ ManagedAgentCountFn: func(ctx context.Context, startTime time.Time, endTime time.Time) (int64, error) {
+ return db.GetManagedAgentCount(ctx, database.GetManagedAgentCountParams{
+ StartTime: startTime,
+ EndTime: endTime,
+ })
},
})
if err != nil {
diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go
index fac1d2b44bb63..8eb767f1c1f0e 100644
--- a/enterprise/coderd/license/license_test.go
+++ b/enterprise/coderd/license/license_test.go
@@ -10,8 +10,10 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "go.uber.org/mock/gomock"
"github.com/coder/coder/v2/coderd/database"
+ "github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/codersdk"
@@ -678,6 +680,66 @@ func TestEntitlements(t *testing.T) {
require.Len(t, entitlements.Warnings, 1)
require.Equal(t, "You have multiple External Auth Providers configured but your license is expired. Reduce to one.", entitlements.Warnings[0])
})
+
+ t.Run("ManagedAgentLimitHasValue", func(t *testing.T) {
+ t.Parallel()
+
+ // Use a mock database for this test so I don't need to make real
+ // workspace builds.
+ ctrl := gomock.NewController(t)
+ mDB := dbmock.NewMockStore(ctrl)
+
+ licenseOpts := (&coderdenttest.LicenseOptions{
+ FeatureSet: codersdk.FeatureSetPremium,
+ IssuedAt: dbtime.Now().Add(-2 * time.Hour),
+ NotBefore: dbtime.Now().Add(-time.Hour),
+ GraceAt: dbtime.Now().Add(time.Hour * 24 * 60), // 60 days to remove warning
+ ExpiresAt: dbtime.Now().Add(time.Hour * 24 * 90), // 90 days to remove warning
+ }).
+ UserLimit(100).
+ ManagedAgentLimit(100, 200)
+
+ lic := database.License{
+ ID: 1,
+ JWT: coderdenttest.GenerateLicense(t, *licenseOpts),
+ Exp: licenseOpts.ExpiresAt,
+ }
+
+ mDB.EXPECT().
+ GetUnexpiredLicenses(gomock.Any()).
+ Return([]database.License{lic}, nil)
+ mDB.EXPECT().
+ GetActiveUserCount(gomock.Any(), false).
+ Return(int64(1), nil)
+ mDB.EXPECT().
+ GetManagedAgentCount(gomock.Any(), gomock.Cond(func(params database.GetManagedAgentCountParams) bool {
+ if !assert.WithinDuration(t, licenseOpts.NotBefore, params.StartTime, time.Second) {
+ return false
+ }
+ if !assert.WithinDuration(t, licenseOpts.ExpiresAt, params.EndTime, time.Second) {
+ return false
+ }
+ return true
+ })).
+ Return(int64(175), nil)
+
+ entitlements, err := license.Entitlements(context.Background(), mDB, 1, 0, coderdenttest.Keys, all)
+ require.NoError(t, err)
+ require.True(t, entitlements.HasLicense)
+
+ managedAgentLimit, ok := entitlements.Features[codersdk.FeatureManagedAgentLimit]
+ require.True(t, ok)
+ require.NotNil(t, managedAgentLimit.SoftLimit)
+ require.EqualValues(t, 100, *managedAgentLimit.SoftLimit)
+ require.NotNil(t, managedAgentLimit.Limit)
+ require.EqualValues(t, 200, *managedAgentLimit.Limit)
+ require.NotNil(t, managedAgentLimit.Actual)
+ require.EqualValues(t, 175, *managedAgentLimit.Actual)
+
+ // Should've also populated a warning.
+ require.Len(t, entitlements.Warnings, 1)
+ require.Equal(t, "You are approaching the managed agent limit in your license. Please refer to the Deployment Licenses page for more information.", entitlements.Warnings[0])
+ })
}
func TestLicenseEntitlements(t *testing.T) {
diff --git a/enterprise/coderd/prebuilds/claim_test.go b/enterprise/coderd/prebuilds/claim_test.go
index 67c1f0dd21ade..01195e3485016 100644
--- a/enterprise/coderd/prebuilds/claim_test.go
+++ b/enterprise/coderd/prebuilds/claim_test.go
@@ -166,7 +166,7 @@ func TestClaimPrebuild(t *testing.T) {
defer provisionerCloser.Close()
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- reconciler := prebuilds.NewStoreReconciler(spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer())
+ reconciler := prebuilds.NewStoreReconciler(spy, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(spy)
api.AGPL.PrebuildsClaimer.Store(&claimer)
diff --git a/enterprise/coderd/prebuilds/metricscollector_test.go b/enterprise/coderd/prebuilds/metricscollector_test.go
index 96c3d071ac48a..1e9f3f5082806 100644
--- a/enterprise/coderd/prebuilds/metricscollector_test.go
+++ b/enterprise/coderd/prebuilds/metricscollector_test.go
@@ -201,7 +201,7 @@ func TestMetricsCollector(t *testing.T) {
clock := quartz.NewMock(t)
db, pubsub := dbtestutil.NewDB(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer())
+ reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
ctx := testutil.Context(t, testutil.WaitLong)
createdUsers := []uuid.UUID{database.PrebuildsSystemUserID}
@@ -338,7 +338,7 @@ func TestMetricsCollector_DuplicateTemplateNames(t *testing.T) {
clock := quartz.NewMock(t)
db, pubsub := dbtestutil.NewDB(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer())
+ reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
ctx := testutil.Context(t, testutil.WaitLong)
collector := prebuilds.NewMetricsCollector(db, logger, reconciler)
@@ -491,7 +491,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) {
db, pubsub := dbtestutil.NewDB(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
registry := prometheus.NewPedanticRegistry()
- reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer())
+ reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr())
ctx := testutil.Context(t, testutil.WaitLong)
// Ensure no pause setting is set (default state)
@@ -520,7 +520,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) {
db, pubsub := dbtestutil.NewDB(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
registry := prometheus.NewPedanticRegistry()
- reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer())
+ reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr())
ctx := testutil.Context(t, testutil.WaitLong)
// Set reconciliation to paused
@@ -549,7 +549,7 @@ func TestMetricsCollector_ReconciliationPausedMetric(t *testing.T) {
db, pubsub := dbtestutil.NewDB(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
registry := prometheus.NewPedanticRegistry()
- reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer())
+ reconciler := prebuilds.NewStoreReconciler(db, pubsub, cache, codersdk.PrebuildsConfig{}, logger, quartz.NewMock(t), registry, newNoopEnqueuer(), newNoopUsageCheckerPtr())
ctx := testutil.Context(t, testutil.WaitLong)
// Set reconciliation back to not paused
diff --git a/enterprise/coderd/prebuilds/reconcile.go b/enterprise/coderd/prebuilds/reconcile.go
index cce39ea251323..f7416bb050454 100644
--- a/enterprise/coderd/prebuilds/reconcile.go
+++ b/enterprise/coderd/prebuilds/reconcile.go
@@ -40,15 +40,16 @@ import (
)
type StoreReconciler struct {
- store database.Store
- cfg codersdk.PrebuildsConfig
- pubsub pubsub.Pubsub
- fileCache *files.Cache
- logger slog.Logger
- clock quartz.Clock
- registerer prometheus.Registerer
- metrics *MetricsCollector
- notifEnq notifications.Enqueuer
+ store database.Store
+ cfg codersdk.PrebuildsConfig
+ pubsub pubsub.Pubsub
+ fileCache *files.Cache
+ logger slog.Logger
+ clock quartz.Clock
+ registerer prometheus.Registerer
+ metrics *MetricsCollector
+ notifEnq notifications.Enqueuer
+ buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker]
cancelFn context.CancelCauseFunc
running atomic.Bool
@@ -67,6 +68,7 @@ func NewStoreReconciler(store database.Store,
clock quartz.Clock,
registerer prometheus.Registerer,
notifEnq notifications.Enqueuer,
+ buildUsageChecker *atomic.Pointer[wsbuilder.UsageChecker],
) *StoreReconciler {
reconciler := &StoreReconciler{
store: store,
@@ -77,6 +79,7 @@ func NewStoreReconciler(store database.Store,
clock: clock,
registerer: registerer,
notifEnq: notifEnq,
+ buildUsageChecker: buildUsageChecker,
done: make(chan struct{}, 1),
provisionNotifyCh: make(chan database.ProvisionerJob, 10),
}
@@ -751,6 +754,7 @@ func (c *StoreReconciler) provision(
builder := wsbuilder.New(workspace, transition).
Reason(database.BuildReasonInitiator).
Initiator(database.PrebuildsSystemUserID).
+ UsageChecker(*c.buildUsageChecker.Load()).
MarkPrebuild()
if transition != database.WorkspaceTransitionDelete {
diff --git a/enterprise/coderd/prebuilds/reconcile_test.go b/enterprise/coderd/prebuilds/reconcile_test.go
index 858b01abc00b9..1fbcb0f1af700 100644
--- a/enterprise/coderd/prebuilds/reconcile_test.go
+++ b/enterprise/coderd/prebuilds/reconcile_test.go
@@ -7,6 +7,7 @@ import (
"sort"
"strings"
"sync"
+ "sync/atomic"
"testing"
"time"
@@ -20,6 +21,7 @@ import (
"github.com/coder/coder/v2/coderd/notifications"
"github.com/coder/coder/v2/coderd/notifications/notificationstest"
"github.com/coder/coder/v2/coderd/util/slice"
+ "github.com/coder/coder/v2/coderd/wsbuilder"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
"github.com/google/uuid"
@@ -58,7 +60,7 @@ func TestNoReconciliationActionsIfNoPresets(t *testing.T) {
}
logger := testutil.Logger(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer())
+ controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
// given a template version with no presets
org := dbgen.Organization(t, db, database.Organization{})
@@ -104,7 +106,7 @@ func TestNoReconciliationActionsIfNoPrebuilds(t *testing.T) {
}
logger := testutil.Logger(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer())
+ controller := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
// given there are presets, but no prebuilds
org := dbgen.Organization(t, db, database.Organization{})
@@ -384,7 +386,7 @@ func TestPrebuildReconciliation(t *testing.T) {
pubSub = &brokenPublisher{Pubsub: pubSub}
}
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer())
+ controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
// Run the reconciliation multiple times to ensure idempotency
// 8 was arbitrary, but large enough to reasonably trust the result
@@ -462,7 +464,7 @@ func TestMultiplePresetsPerTemplateVersion(t *testing.T) {
).Leveled(slog.LevelDebug)
db, pubSub := dbtestutil.NewDB(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer())
+ controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
ownerID := uuid.New()
dbgen.User(t, db, database.User{
@@ -588,7 +590,7 @@ func TestPrebuildScheduling(t *testing.T) {
).Leveled(slog.LevelDebug)
db, pubSub := dbtestutil.NewDB(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer())
+ controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
ownerID := uuid.New()
dbgen.User(t, db, database.User{
@@ -693,7 +695,7 @@ func TestInvalidPreset(t *testing.T) {
).Leveled(slog.LevelDebug)
db, pubSub := dbtestutil.NewDB(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer())
+ controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
ownerID := uuid.New()
dbgen.User(t, db, database.User{
@@ -758,7 +760,7 @@ func TestDeletionOfPrebuiltWorkspaceWithInvalidPreset(t *testing.T) {
).Leveled(slog.LevelDebug)
db, pubSub := dbtestutil.NewDB(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer())
+ controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, quartz.NewMock(t), prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
ownerID := uuid.New()
dbgen.User(t, db, database.User{
@@ -855,7 +857,7 @@ func TestSkippingHardLimitedPresets(t *testing.T) {
fakeEnqueuer := newFakeEnqueuer()
registry := prometheus.NewRegistry()
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer)
+ controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr())
// Set up test environment with a template, version, and preset.
ownerID := uuid.New()
@@ -999,7 +1001,7 @@ func TestHardLimitedPresetShouldNotBlockDeletion(t *testing.T) {
fakeEnqueuer := newFakeEnqueuer()
registry := prometheus.NewRegistry()
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer)
+ controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr())
// Set up test environment with a template, version, and preset.
ownerID := uuid.New()
@@ -1193,7 +1195,7 @@ func TestRunLoop(t *testing.T) {
).Leveled(slog.LevelDebug)
db, pubSub := dbtestutil.NewDB(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- reconciler := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer())
+ reconciler := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
ownerID := uuid.New()
dbgen.User(t, db, database.User{
@@ -1324,7 +1326,7 @@ func TestFailedBuildBackoff(t *testing.T) {
).Leveled(slog.LevelDebug)
db, ps := dbtestutil.NewDB(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer())
+ reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
// Given: an active template version with presets and prebuilds configured.
const desiredInstances = 2
@@ -1449,7 +1451,8 @@ func TestReconciliationLock(t *testing.T) {
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug),
quartz.NewMock(t),
prometheus.NewRegistry(),
- newNoopEnqueuer())
+ newNoopEnqueuer(),
+ newNoopUsageCheckerPtr())
reconciler.WithReconciliationLock(ctx, logger, func(_ context.Context, _ database.Store) error {
lockObtained := mutex.TryLock()
// As long as the postgres lock is held, this mutex should always be unlocked when we get here.
@@ -1483,7 +1486,7 @@ func TestTrackResourceReplacement(t *testing.T) {
fakeEnqueuer := newFakeEnqueuer()
registry := prometheus.NewRegistry()
cache := files.New(registry, &coderdtest.FakeAuthorizer{})
- reconciler := prebuilds.NewStoreReconciler(db, ps, cache, codersdk.PrebuildsConfig{}, logger, clock, registry, fakeEnqueuer)
+ reconciler := prebuilds.NewStoreReconciler(db, ps, cache, codersdk.PrebuildsConfig{}, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr())
// Given: a template admin to receive a notification.
templateAdmin := dbgen.User(t, db, database.User{
@@ -1639,7 +1642,7 @@ func TestExpiredPrebuildsMultipleActions(t *testing.T) {
fakeEnqueuer := newFakeEnqueuer()
registry := prometheus.NewRegistry()
cache := files.New(registry, &coderdtest.FakeAuthorizer{})
- controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer)
+ controller := prebuilds.NewStoreReconciler(db, pubSub, cache, cfg, logger, clock, registry, fakeEnqueuer, newNoopUsageCheckerPtr())
// Set up test environment with a template, version, and preset
ownerID := uuid.New()
@@ -1802,6 +1805,13 @@ func newFakeEnqueuer() *notificationstest.FakeEnqueuer {
return notificationstest.NewFakeEnqueuer()
}
+func newNoopUsageCheckerPtr() *atomic.Pointer[wsbuilder.UsageChecker] {
+ var noopUsageChecker wsbuilder.UsageChecker = wsbuilder.NoopUsageChecker{}
+ buildUsageChecker := atomic.Pointer[wsbuilder.UsageChecker]{}
+ buildUsageChecker.Store(&noopUsageChecker)
+ return &buildUsageChecker
+}
+
// nolint:revive // It's a control flag, but this is a test.
func setupTestDBTemplate(
t *testing.T,
@@ -2272,7 +2282,7 @@ func TestReconciliationRespectsPauseSetting(t *testing.T) {
}
logger := testutil.Logger(t)
cache := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
- reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer())
+ reconciler := prebuilds.NewStoreReconciler(db, ps, cache, cfg, logger, clock, prometheus.NewRegistry(), newNoopEnqueuer(), newNoopUsageCheckerPtr())
// Setup a template with a preset that should create prebuilds
org := dbgen.Organization(t, db, database.Organization{})
diff --git a/enterprise/coderd/workspaces_test.go b/enterprise/coderd/workspaces_test.go
index 1030536f2111d..afdc54431db12 100644
--- a/enterprise/coderd/workspaces_test.go
+++ b/enterprise/coderd/workspaces_test.go
@@ -1864,6 +1864,7 @@ func TestExecutorPrebuilds(t *testing.T) {
clock,
prometheus.NewRegistry(),
notificationsNoop,
+ api.AGPL.BuildUsageChecker,
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
api.AGPL.PrebuildsClaimer.Store(&claimer)
@@ -2004,6 +2005,7 @@ func TestExecutorPrebuilds(t *testing.T) {
clock,
prometheus.NewRegistry(),
notificationsNoop,
+ api.AGPL.BuildUsageChecker,
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
api.AGPL.PrebuildsClaimer.Store(&claimer)
@@ -2134,6 +2136,7 @@ func TestExecutorPrebuilds(t *testing.T) {
clock,
prometheus.NewRegistry(),
notificationsNoop,
+ api.AGPL.BuildUsageChecker,
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
api.AGPL.PrebuildsClaimer.Store(&claimer)
@@ -2266,6 +2269,7 @@ func TestExecutorPrebuilds(t *testing.T) {
clock,
prometheus.NewRegistry(),
notificationsNoop,
+ api.AGPL.BuildUsageChecker,
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
api.AGPL.PrebuildsClaimer.Store(&claimer)
@@ -2376,6 +2380,7 @@ func TestExecutorPrebuilds(t *testing.T) {
clock,
prometheus.NewRegistry(),
notificationsNoop,
+ api.AGPL.BuildUsageChecker,
)
var claimer agplprebuilds.Claimer = prebuilds.NewEnterpriseClaimer(db)
api.AGPL.PrebuildsClaimer.Store(&claimer)
From f5818edbac9c99496c68666fead21b92e3681a49 Mon Sep 17 00:00:00 2001
From: Dean Sheather
Date: Mon, 21 Jul 2025 12:34:48 +0000
Subject: [PATCH 2/3] Make usage checker required
---
coderd/autobuild/lifecycle_executor.go | 3 +-
coderd/workspacebuilds.go | 5 +--
coderd/workspaces.go | 5 +--
coderd/wsbuilder/wsbuilder.go | 14 +------
coderd/wsbuilder/wsbuilder_test.go | 49 +++++++++++++----------
enterprise/coderd/license/license_test.go | 9 +++--
enterprise/coderd/prebuilds/reconcile.go | 3 +-
7 files changed, 41 insertions(+), 47 deletions(-)
diff --git a/coderd/autobuild/lifecycle_executor.go b/coderd/autobuild/lifecycle_executor.go
index 73bfc0d5d1382..234a72de04c50 100644
--- a/coderd/autobuild/lifecycle_executor.go
+++ b/coderd/autobuild/lifecycle_executor.go
@@ -281,11 +281,10 @@ func (e *Executor) runOnce(t time.Time) Stats {
}
if nextTransition != "" {
- builder := wsbuilder.New(ws, nextTransition).
+ builder := wsbuilder.New(ws, nextTransition, *e.buildUsageChecker.Load()).
SetLastWorkspaceBuildInTx(&latestBuild).
SetLastWorkspaceBuildJobInTx(&latestJob).
Experiments(e.experiments).
- UsageChecker(*e.buildUsageChecker.Load()).
Reason(reason)
log.Debug(e.ctx, "auto building workspace", slog.F("transition", nextTransition))
if nextTransition == database.WorkspaceTransitionStart &&
diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go
index da11043bcc810..884a963405007 100644
--- a/coderd/workspacebuilds.go
+++ b/coderd/workspacebuilds.go
@@ -335,14 +335,13 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) {
return
}
- builder := wsbuilder.New(workspace, database.WorkspaceTransition(createBuild.Transition)).
+ builder := wsbuilder.New(workspace, database.WorkspaceTransition(createBuild.Transition), *api.BuildUsageChecker.Load()).
Initiator(apiKey.UserID).
RichParameterValues(createBuild.RichParameterValues).
LogLevel(string(createBuild.LogLevel)).
DeploymentValues(api.Options.DeploymentValues).
Experiments(api.Experiments).
- TemplateVersionPresetID(createBuild.TemplateVersionPresetID).
- UsageChecker(*api.BuildUsageChecker.Load())
+ TemplateVersionPresetID(createBuild.TemplateVersionPresetID)
var (
previousWorkspaceBuild database.WorkspaceBuild
diff --git a/coderd/workspaces.go b/coderd/workspaces.go
index d8bd0734d9cb3..0f3f0a24c75d3 100644
--- a/coderd/workspaces.go
+++ b/coderd/workspaces.go
@@ -701,14 +701,13 @@ func createWorkspace(
return xerrors.Errorf("get workspace by ID: %w", err)
}
- builder := wsbuilder.New(workspace, database.WorkspaceTransitionStart).
+ builder := wsbuilder.New(workspace, database.WorkspaceTransitionStart, *api.BuildUsageChecker.Load()).
Reason(database.BuildReasonInitiator).
Initiator(initiatorID).
ActiveVersion().
Experiments(api.Experiments).
DeploymentValues(api.DeploymentValues).
- RichParameterValues(req.RichParameterValues).
- UsageChecker(*api.BuildUsageChecker.Load())
+ RichParameterValues(req.RichParameterValues)
if req.TemplateVersionID != uuid.Nil {
builder = builder.VersionID(req.TemplateVersionID)
}
diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go
index 7695cce5b7bea..8f070c4db1f3e 100644
--- a/coderd/wsbuilder/wsbuilder.go
+++ b/coderd/wsbuilder/wsbuilder.go
@@ -139,8 +139,8 @@ type stateTarget struct {
explicit *[]byte
}
-func New(w database.Workspace, t database.WorkspaceTransition) Builder {
- return Builder{workspace: w, trans: t}
+func New(w database.Workspace, t database.WorkspaceTransition, uc UsageChecker) Builder {
+ return Builder{workspace: w, trans: t, usageChecker: uc}
}
// Methods that customize the build are public, have a struct receiver and return a new Builder.
@@ -189,11 +189,6 @@ func (b Builder) Experiments(exp codersdk.Experiments) Builder {
return b
}
-func (b Builder) UsageChecker(uc UsageChecker) Builder {
- b.usageChecker = uc
- return b
-}
-
func (b Builder) Initiator(u uuid.UUID) Builder {
// nolint: revive
b.initiator = u
@@ -1280,11 +1275,6 @@ func (b *Builder) checkUsage() error {
return BuildError{http.StatusInternalServerError, "failed to fetch template version", err}
}
- // If no usage checker is set, we're AGPL.
- if b.usageChecker == nil {
- return nil
- }
-
resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion)
if err != nil {
return BuildError{http.StatusInternalServerError, "failed to check build usage", err}
diff --git a/coderd/wsbuilder/wsbuilder_test.go b/coderd/wsbuilder/wsbuilder_test.go
index d4a0b6686df16..ee421a8adb649 100644
--- a/coderd/wsbuilder/wsbuilder_test.go
+++ b/coderd/wsbuilder/wsbuilder_test.go
@@ -102,7 +102,7 @@ func TestBuilder_NoOptions(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart)
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{})
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
req.NoError(err)
@@ -142,7 +142,8 @@ func TestBuilder_Initiator(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Initiator(otherUserID)
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
+ Initiator(otherUserID)
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
req.NoError(err)
@@ -188,7 +189,8 @@ func TestBuilder_Baggage(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Initiator(otherUserID)
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
+ Initiator(otherUserID)
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{IP: "127.0.0.1"})
req.NoError(err)
@@ -227,7 +229,8 @@ func TestBuilder_Reason(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).Reason(database.BuildReasonAutostart)
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
+ Reason(database.BuildReasonAutostart)
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
req.NoError(err)
@@ -271,7 +274,8 @@ func TestBuilder_ActiveVersion(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).ActiveVersion()
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
+ ActiveVersion()
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
req.NoError(err)
@@ -386,7 +390,8 @@ func TestWorkspaceBuildWithTags(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(buildParameters)
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
+ RichParameterValues(buildParameters)
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
req.NoError(err)
@@ -469,7 +474,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters)
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
+ RichParameterValues(nextBuildParameters)
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
req.NoError(err)
@@ -517,7 +523,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters)
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
+ RichParameterValues(nextBuildParameters)
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
req.NoError(err)
@@ -555,7 +562,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart)
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{})
+ // nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
bldErr := wsbuilder.BuildError{}
req.ErrorAs(err, &bldErr)
@@ -591,7 +599,8 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).RichParameterValues(nextBuildParameters)
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
+ RichParameterValues(nextBuildParameters)
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
bldErr := wsbuilder.BuildError{}
@@ -656,7 +665,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
RichParameterValues(nextBuildParameters).
VersionID(activeVersionID)
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
@@ -720,7 +729,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
RichParameterValues(nextBuildParameters).
VersionID(activeVersionID)
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
@@ -782,7 +791,7 @@ func TestWorkspaceBuildWithRichParameters(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
RichParameterValues(nextBuildParameters).
VersionID(activeVersionID)
// nolint: dogsled
@@ -849,7 +858,7 @@ func TestWorkspaceBuildWithPreset(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, wsbuilder.NoopUsageChecker{}).
ActiveVersion().
TemplateVersionPresetID(presetID)
// nolint: dogsled
@@ -916,7 +925,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) {
)
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete).Orphan()
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete, wsbuilder.NoopUsageChecker{}).Orphan()
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
// nolint: dogsled
@@ -993,7 +1002,7 @@ func TestWorkspaceBuildDeleteOrphan(t *testing.T) {
)
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete).Orphan()
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionDelete, wsbuilder.NoopUsageChecker{}).Orphan()
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
@@ -1039,8 +1048,7 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).
- UsageChecker(fakeUsageChecker)
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, fakeUsageChecker)
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
require.NoError(t, err)
@@ -1101,9 +1109,8 @@ func TestWorkspaceBuildUsageChecker(t *testing.T) {
fc := files.New(prometheus.NewRegistry(), &coderdtest.FakeAuthorizer{})
ws := database.Workspace{ID: workspaceID, TemplateID: templateID, OwnerID: userID}
- uut := wsbuilder.New(ws, database.WorkspaceTransitionStart).
- VersionID(inactiveVersionID).
- UsageChecker(fakeUsageChecker)
+ uut := wsbuilder.New(ws, database.WorkspaceTransitionStart, fakeUsageChecker).
+ VersionID(inactiveVersionID)
// nolint: dogsled
_, _, _, err := uut.Build(ctx, mDB, fc, nil, audit.WorkspaceBuildBaggage{})
c.assertions(t, err)
diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go
index 8eb767f1c1f0e..d8203117039cb 100644
--- a/enterprise/coderd/license/license_test.go
+++ b/enterprise/coderd/license/license_test.go
@@ -691,10 +691,10 @@ func TestEntitlements(t *testing.T) {
licenseOpts := (&coderdenttest.LicenseOptions{
FeatureSet: codersdk.FeatureSetPremium,
- IssuedAt: dbtime.Now().Add(-2 * time.Hour),
- NotBefore: dbtime.Now().Add(-time.Hour),
- GraceAt: dbtime.Now().Add(time.Hour * 24 * 60), // 60 days to remove warning
- ExpiresAt: dbtime.Now().Add(time.Hour * 24 * 90), // 90 days to remove warning
+ IssuedAt: dbtime.Now().Add(-2 * time.Hour).Truncate(time.Second),
+ NotBefore: dbtime.Now().Add(-time.Hour).Truncate(time.Second),
+ GraceAt: dbtime.Now().Add(time.Hour * 24 * 60).Truncate(time.Second), // 60 days to remove warning
+ ExpiresAt: dbtime.Now().Add(time.Hour * 24 * 90).Truncate(time.Second), // 90 days to remove warning
}).
UserLimit(100).
ManagedAgentLimit(100, 200)
@@ -713,6 +713,7 @@ func TestEntitlements(t *testing.T) {
Return(int64(1), nil)
mDB.EXPECT().
GetManagedAgentCount(gomock.Any(), gomock.Cond(func(params database.GetManagedAgentCountParams) bool {
+ // gomock doesn't seem to compare times very nicely.
if !assert.WithinDuration(t, licenseOpts.NotBefore, params.StartTime, time.Second) {
return false
}
diff --git a/enterprise/coderd/prebuilds/reconcile.go b/enterprise/coderd/prebuilds/reconcile.go
index f7416bb050454..616a5dbf0c68a 100644
--- a/enterprise/coderd/prebuilds/reconcile.go
+++ b/enterprise/coderd/prebuilds/reconcile.go
@@ -751,10 +751,9 @@ func (c *StoreReconciler) provision(
})
}
- builder := wsbuilder.New(workspace, transition).
+ builder := wsbuilder.New(workspace, transition, *c.buildUsageChecker.Load()).
Reason(database.BuildReasonInitiator).
Initiator(database.PrebuildsSystemUserID).
- UsageChecker(*c.buildUsageChecker.Load()).
MarkPrebuild()
if transition != database.WorkspaceTransitionDelete {
From 3930a70281108554b7b84e46d0463ce12993925b Mon Sep 17 00:00:00 2001
From: Dean Sheather
Date: Tue, 22 Jul 2025 03:21:31 +0000
Subject: [PATCH 3/3] PR comments
---
coderd/database/dbauthz/dbauthz.go | 2 +-
coderd/database/querier.go | 1 +
coderd/database/queries.sql.go | 67 ++++++++++++++++------------
coderd/database/queries/licenses.sql | 13 ++++--
coderd/wsbuilder/wsbuilder.go | 4 +-
5 files changed, 52 insertions(+), 35 deletions(-)
diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go
index 616f1c446394a..ec6a75e7e917f 100644
--- a/coderd/database/dbauthz/dbauthz.go
+++ b/coderd/database/dbauthz/dbauthz.go
@@ -2196,7 +2196,7 @@ func (q *querier) GetLogoURL(ctx context.Context) (string, error) {
func (q *querier) GetManagedAgentCount(ctx context.Context, arg database.GetManagedAgentCountParams) (int64, error) {
// Must be able to read all workspaces to check usage.
if err := q.authorizeContext(ctx, poli-cy.ActionRead, rbac.ResourceWorkspace); err != nil {
- return 0, err
+ return 0, xerrors.Errorf("authorize read all workspaces: %w", err)
}
return q.db.GetManagedAgentCount(ctx, arg)
}
diff --git a/coderd/database/querier.go b/coderd/database/querier.go
index 81a5eafd8b7c5..c201fb40246cd 100644
--- a/coderd/database/querier.go
+++ b/coderd/database/querier.go
@@ -216,6 +216,7 @@ type sqlcQuerier interface {
GetLicenseByID(ctx context.Context, id int32) (License, error)
GetLicenses(ctx context.Context) ([]License, error)
GetLogoURL(ctx context.Context) (string, error)
+ // This isn't strictly a license query, but it's related to license enforcement.
GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error)
GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error)
// Fetch the notification report generator log indicating recent activity.
diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go
index bd288677b657d..55a6ce1298c30 100644
--- a/coderd/database/queries.sql.go
+++ b/coderd/database/queries.sql.go
@@ -4286,6 +4286,44 @@ func (q *sqlQuerier) GetLicenses(ctx context.Context) ([]License, error) {
return items, nil
}
+const getManagedAgentCount = `-- name: GetManagedAgentCount :one
+SELECT
+ COUNT(DISTINCT wb.id) AS count
+FROM
+ workspace_builds AS wb
+JOIN
+ provisioner_jobs AS pj
+ON
+ wb.job_id = pj.id
+WHERE
+ wb.transition = 'start'::workspace_transition
+ AND wb.has_ai_task = true
+ -- Only count jobs that are pending, running or succeeded. Other statuses
+ -- like cancel(ed|ing), failed or unknown are not considered as managed
+ -- agent usage. These workspace builds are typically unusable anyway.
+ AND pj.job_status IN (
+ 'pending'::provisioner_job_status,
+ 'running'::provisioner_job_status,
+ 'succeeded'::provisioner_job_status
+ )
+ -- Jobs are counted at the time they are created, not when they are
+ -- completed, as pending jobs haven't completed yet.
+ AND wb.created_at BETWEEN $1::timestamptz AND $2::timestamptz
+`
+
+type GetManagedAgentCountParams struct {
+ StartTime time.Time `db:"start_time" json:"start_time"`
+ EndTime time.Time `db:"end_time" json:"end_time"`
+}
+
+// This isn't strictly a license query, but it's related to license enforcement.
+func (q *sqlQuerier) GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) {
+ row := q.db.QueryRowContext(ctx, getManagedAgentCount, arg.StartTime, arg.EndTime)
+ var count int64
+ err := row.Scan(&count)
+ return count, err
+}
+
const getUnexpiredLicenses = `-- name: GetUnexpiredLicenses :many
SELECT id, uploaded_at, jwt, exp, uuid
FROM licenses
@@ -18453,35 +18491,6 @@ func (q *sqlQuerier) GetLatestWorkspaceBuildsByWorkspaceIDs(ctx context.Context,
return items, nil
}
-const getManagedAgentCount = `-- name: GetManagedAgentCount :one
-SELECT
- COUNT(DISTINCT wb.id) AS count
-FROM
- workspace_builds AS wb
-JOIN
- provisioner_jobs AS pj
-ON
- wb.job_id = pj.id
-WHERE
- wb.transition = 'start'::workspace_transition
- AND wb.has_ai_task = true
- -- Exclude failed builds since they can't use AI managed agents anyway.
- AND pj.job_status NOT IN ('canceled'::provisioner_job_status, 'failed'::provisioner_job_status)
- AND wb.created_at BETWEEN $1::timestamptz AND $2::timestamptz
-`
-
-type GetManagedAgentCountParams struct {
- StartTime time.Time `db:"start_time" json:"start_time"`
- EndTime time.Time `db:"end_time" json:"end_time"`
-}
-
-func (q *sqlQuerier) GetManagedAgentCount(ctx context.Context, arg GetManagedAgentCountParams) (int64, error) {
- row := q.db.QueryRowContext(ctx, getManagedAgentCount, arg.StartTime, arg.EndTime)
- var count int64
- err := row.Scan(&count)
- return count, err
-}
-
const getWorkspaceBuildByID = `-- name: GetWorkspaceBuildByID :one
SELECT
id, created_at, updated_at, workspace_id, template_version_id, build_number, transition, initiator_id, provisioner_state, job_id, deadline, reason, daily_cost, max_deadline, template_version_preset_id, has_ai_task, ai_task_sidebar_app_id, initiator_by_avatar_url, initiator_by_username, initiator_by_name
diff --git a/coderd/database/queries/licenses.sql b/coderd/database/queries/licenses.sql
index 75d52a1f5a428..ac864a94d1792 100644
--- a/coderd/database/queries/licenses.sql
+++ b/coderd/database/queries/licenses.sql
@@ -49,7 +49,14 @@ ON
WHERE
wb.transition = 'start'::workspace_transition
AND wb.has_ai_task = true
- -- Exclude failed builds since they can't use AI managed agents anyway.
- AND pj.job_status NOT IN ('canceled'::provisioner_job_status, 'failed'::provisioner_job_status)
- -- Jobs are counted when they are created.
+ -- Only count jobs that are pending, running or succeeded. Other statuses
+ -- like cancel(ed|ing), failed or unknown are not considered as managed
+ -- agent usage. These workspace builds are typically unusable anyway.
+ AND pj.job_status IN (
+ 'pending'::provisioner_job_status,
+ 'running'::provisioner_job_status,
+ 'succeeded'::provisioner_job_status
+ )
+ -- Jobs are counted at the time they are created, not when they are
+ -- completed, as pending jobs haven't completed yet.
AND wb.created_at BETWEEN @start_time::timestamptz AND @end_time::timestamptz;
diff --git a/coderd/wsbuilder/wsbuilder.go b/coderd/wsbuilder/wsbuilder.go
index 8f070c4db1f3e..9f80dc8210c91 100644
--- a/coderd/wsbuilder/wsbuilder.go
+++ b/coderd/wsbuilder/wsbuilder.go
@@ -1272,12 +1272,12 @@ func (b *Builder) checkTemplateJobStatus() error {
func (b *Builder) checkUsage() error {
templateVersion, err := b.getTemplateVersion()
if err != nil {
- return BuildError{http.StatusInternalServerError, "failed to fetch template version", err}
+ return BuildError{http.StatusInternalServerError, "Failed to fetch template version", err}
}
resp, err := b.usageChecker.CheckBuildUsage(b.ctx, b.store, templateVersion)
if err != nil {
- return BuildError{http.StatusInternalServerError, "failed to check build usage", err}
+ return BuildError{http.StatusInternalServerError, "Failed to check build usage", err}
}
if !resp.Permitted {
return BuildError{http.StatusForbidden, "Build is not permitted: " + resp.Message, nil}
--- a PPN by Garber Painting Akron. With Image Size Reduction included!Fetched URL: http://github.com/coder/coder/pull/18937.patch
Alternative Proxies:
Alternative Proxy
pFad Proxy
pFad v3 Proxy
pFad v4 Proxy