Skip to content

Commit d03fe73

Browse files
committed
chore: populate connectionlog count using a separate query
1 parent 8cec6d1 commit d03fe73

File tree

14 files changed

+663
-5
lines changed

14 files changed

+663
-5
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,21 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
13231323
return q.db.CleanTailnetTunnels(ctx)
13241324
}
13251325

1326+
func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
1327+
// Just like the actual query, shortcut if the user is an owner.
1328+
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
1329+
if err == nil {
1330+
return q.db.CountConnectionLogs(ctx, arg)
1331+
}
1332+
1333+
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceConnectionLog.Type)
1334+
if err != nil {
1335+
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
1336+
}
1337+
1338+
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
1339+
}
1340+
13261341
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
13271342
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
13281343
return nil, err
@@ -5301,3 +5316,7 @@ func (q *querier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg database
53015316
func (q *querier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, _ rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
53025317
return q.GetConnectionLogsOffset(ctx, arg)
53035318
}
5319+
5320+
func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, _ rbac.PreparedAuthorized) (int64, error) {
5321+
return q.CountConnectionLogs(ctx, arg)
5322+
}

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,42 @@ func (s *MethodTestSuite) TestConnectionLogs() {
391391
LimitOpt: 10,
392392
}, emptyPreparedAuthorized{}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead)
393393
}))
394+
s.Run("CountConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
395+
ws := createWorkspace(s.T(), db)
396+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
397+
Type: database.ConnectionTypeSsh,
398+
WorkspaceID: ws.ID,
399+
OrganizationID: ws.OrganizationID,
400+
WorkspaceOwnerID: ws.OwnerID,
401+
})
402+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
403+
Type: database.ConnectionTypeSsh,
404+
WorkspaceID: ws.ID,
405+
OrganizationID: ws.OrganizationID,
406+
WorkspaceOwnerID: ws.OwnerID,
407+
})
408+
check.Args(database.CountConnectionLogsParams{}).Asserts(
409+
rbac.ResourceConnectionLog, policy.ActionRead,
410+
).WithNotAuthorized("nil")
411+
}))
412+
s.Run("CountAuthorizedConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
413+
ws := createWorkspace(s.T(), db)
414+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
415+
Type: database.ConnectionTypeSsh,
416+
WorkspaceID: ws.ID,
417+
OrganizationID: ws.OrganizationID,
418+
WorkspaceOwnerID: ws.OwnerID,
419+
})
420+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
421+
Type: database.ConnectionTypeSsh,
422+
WorkspaceID: ws.ID,
423+
OrganizationID: ws.OrganizationID,
424+
WorkspaceOwnerID: ws.OwnerID,
425+
})
426+
check.Args(database.CountConnectionLogsParams{}, emptyPreparedAuthorized{}).Asserts(
427+
rbac.ResourceConnectionLog, policy.ActionRead,
428+
)
429+
}))
394430
}
395431

396432
func (s *MethodTestSuite) TestFile() {

coderd/database/dbauthz/setup_test.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
271271

272272
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
273273
// any case where the error is nil and the response is an empty slice.
274-
if err != nil || !hasEmptySliceResponse(resp) {
274+
if err != nil || !hasEmptyResponse(resp) {
275275
// Expect the default error
276276
if testCase.notAuthorizedExpect == "" {
277277
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
@@ -297,7 +297,7 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
297297

298298
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
299299
// any case where the error is nil and the response is an empty slice.
300-
if err != nil || !hasEmptySliceResponse(resp) {
300+
if err != nil || !hasEmptyResponse(resp) {
301301
if testCase.cancelledCtxExpect == "" {
302302
s.Errorf(err, "method should an error with cancellation")
303303
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
@@ -308,13 +308,20 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
308308
})
309309
}
310310

311-
func hasEmptySliceResponse(values []reflect.Value) bool {
311+
func hasEmptyResponse(values []reflect.Value) bool {
312312
for _, r := range values {
313313
if r.Kind() == reflect.Slice || r.Kind() == reflect.Array {
314314
if r.Len() == 0 {
315315
return true
316316
}
317317
}
318+
319+
// Special case for int64, as it's the return type for count queries.
320+
if r.Kind() == reflect.Int64 {
321+
if r.Int() == 0 {
322+
return true
323+
}
324+
}
318325
}
319326
return false
320327
}

coderd/database/dbmem/dbmem.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,6 +1780,10 @@ func (*FakeQuerier) CleanTailnetTunnels(context.Context) error {
17801780
return ErrUnimplemented
17811781
}
17821782

1783+
func (q *FakeQuerier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
1784+
return q.CountAuthorizedConnectionLogs(ctx, arg, nil)
1785+
}
1786+
17831787
func (q *FakeQuerier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
17841788
return nil, ErrUnimplemented
17851789
}
@@ -14156,3 +14160,93 @@ func (q *FakeQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg
1415614160

1415714161
return logs, nil
1415814162
}
14163+
14164+
func (q *FakeQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
14165+
if err := validateDatabaseType(arg); err != nil {
14166+
return 0, err
14167+
}
14168+
14169+
// Call this to match the same function calls as the SQL implementation.
14170+
// It functionally does nothing for filtering.
14171+
if prepared != nil {
14172+
_, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
14173+
VariableConverter: regosql.ConnectionLogConverter(),
14174+
})
14175+
if err != nil {
14176+
return 0, err
14177+
}
14178+
}
14179+
14180+
q.mutex.RLock()
14181+
defer q.mutex.RUnlock()
14182+
14183+
var count int64
14184+
14185+
for _, clog := range q.connectionLogs {
14186+
if arg.OrganizationID != uuid.Nil && clog.OrganizationID != arg.OrganizationID {
14187+
continue
14188+
}
14189+
if arg.WorkspaceOwner != "" {
14190+
workspaceOwner, err := q.getUserByIDNoLock(clog.WorkspaceOwnerID)
14191+
if err == nil && !strings.EqualFold(arg.WorkspaceOwner, workspaceOwner.Username) {
14192+
continue
14193+
}
14194+
}
14195+
if arg.Type != "" && string(clog.Type) != arg.Type {
14196+
continue
14197+
}
14198+
if arg.UserID != uuid.Nil && (!clog.UserID.Valid || clog.UserID.UUID != arg.UserID) {
14199+
continue
14200+
}
14201+
if arg.Username != "" {
14202+
if !clog.UserID.Valid {
14203+
continue
14204+
}
14205+
user, err := q.getUserByIDNoLock(clog.UserID.UUID)
14206+
if err != nil || user.Username != arg.Username {
14207+
continue
14208+
}
14209+
}
14210+
if arg.Email != "" {
14211+
if !clog.UserID.Valid {
14212+
continue
14213+
}
14214+
user, err := q.getUserByIDNoLock(clog.UserID.UUID)
14215+
if err != nil || user.Email != arg.Email {
14216+
continue
14217+
}
14218+
}
14219+
if !arg.StartedAfter.IsZero() && clog.Time.Before(arg.StartedAfter) {
14220+
continue
14221+
}
14222+
if !arg.StartedBefore.IsZero() && clog.Time.After(arg.StartedBefore) {
14223+
continue
14224+
}
14225+
if !arg.ClosedAfter.IsZero() && (!clog.CloseTime.Valid || clog.CloseTime.Time.Before(arg.ClosedAfter)) {
14226+
continue
14227+
}
14228+
if !arg.ClosedBefore.IsZero() && (!clog.CloseTime.Valid || clog.CloseTime.Time.After(arg.ClosedBefore)) {
14229+
continue
14230+
}
14231+
if arg.WorkspaceID != uuid.Nil && clog.WorkspaceID != arg.WorkspaceID {
14232+
continue
14233+
}
14234+
if arg.ConnectionID != uuid.Nil && (!clog.ConnectionID.Valid || clog.ConnectionID.UUID != arg.ConnectionID) {
14235+
continue
14236+
}
14237+
if arg.Status != "" {
14238+
isConnected := !clog.CloseTime.Valid
14239+
if (arg.Status == "connected" && !isConnected) || (arg.Status == "disconnected" && isConnected) {
14240+
continue
14241+
}
14242+
}
14243+
14244+
if prepared != nil && prepared.Authorize(ctx, clog.RBACObject()) != nil {
14245+
continue
14246+
}
14247+
14248+
count++
14249+
}
14250+
14251+
return count, nil
14252+
}

coderd/database/dbmetrics/querymetrics.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/modelqueries.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ func (q *sqlQuerier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg GetAu
566566

567567
type connectionLogQuerier interface {
568568
GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error)
569+
CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error)
569570
}
570571

571572
func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) {
@@ -653,6 +654,53 @@ func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg
653654
return items, nil
654655
}
655656

657+
func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
658+
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
659+
VariableConverter: regosql.ConnectionLogConverter(),
660+
})
661+
if err != nil {
662+
return 0, xerrors.Errorf("compile authorized filter: %w", err)
663+
}
664+
filtered, err := insertAuthorizedFilter(countConnectionLogs, fmt.Sprintf(" AND %s", authorizedFilter))
665+
if err != nil {
666+
return 0, xerrors.Errorf("insert authorized filter: %w", err)
667+
}
668+
669+
query := fmt.Sprintf("-- name: CountAuthorizedConnectionLogs :one\n%s", filtered)
670+
rows, err := q.db.QueryContext(ctx, query,
671+
arg.OrganizationID,
672+
arg.WorkspaceOwner,
673+
arg.Type,
674+
arg.UserID,
675+
arg.Username,
676+
arg.Email,
677+
arg.StartedAfter,
678+
arg.StartedBefore,
679+
arg.ClosedAfter,
680+
arg.ClosedBefore,
681+
arg.WorkspaceID,
682+
arg.ConnectionID,
683+
arg.Status,
684+
)
685+
if err != nil {
686+
return 0, err
687+
}
688+
defer rows.Close()
689+
var count int64
690+
for rows.Next() {
691+
if err := rows.Scan(&count); err != nil {
692+
return 0, err
693+
}
694+
}
695+
if err := rows.Close(); err != nil {
696+
return 0, err
697+
}
698+
if err := rows.Err(); err != nil {
699+
return 0, err
700+
}
701+
return count, nil
702+
}
703+
656704
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
657705
if !strings.Contains(query, authorizedQueryPlaceholder) {
658706
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")

coderd/database/modelqueries_internal_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package database
22

33
import (
4+
"regexp"
5+
"strings"
46
"testing"
57
"time"
68

@@ -54,3 +56,35 @@ func TestWorkspaceTableConvert(t *testing.T) {
5456
"'workspace.WorkspaceTable()' is not missing at least 1 field when converting to 'WorkspaceTable'. "+
5557
"To resolve this, go to the 'func (w Workspace) WorkspaceTable()' and ensure all fields are converted.")
5658
}
59+
60+
func TestConnectionLogsQueryConsistency(t *testing.T) {
61+
t.Parallel()
62+
63+
getWhereClause := extractWhereClause(getConnectionLogsOffset)
64+
require.NotEmpty(t, getWhereClause, "getConnectionLogsOffset query should have a WHERE clause")
65+
66+
countWhereClause := extractWhereClause(countConnectionLogs)
67+
require.NotEmpty(t, countWhereClause, "countConnectionLogs query should have a WHERE clause")
68+
69+
require.Equal(t, getWhereClause, countWhereClause, "getConnectionLogsOffset and countConnectionLogs queries should have the same WHERE clause")
70+
}
71+
72+
// extractWhereClause extracts the WHERE clause from a SQL query string
73+
func extractWhereClause(query string) string {
74+
// Find WHERE and get everything after it
75+
wherePattern := regexp.MustCompile(`(?is)WHERE\s+(.*)`)
76+
whereMatches := wherePattern.FindStringSubmatch(query)
77+
if len(whereMatches) < 2 {
78+
return ""
79+
}
80+
81+
whereClause := whereMatches[1]
82+
83+
// Remove ORDER BY, LIMIT, OFFSET clauses from the end
84+
whereClause = regexp.MustCompile(`(?is)\s+(ORDER BY|LIMIT|OFFSET).*$`).ReplaceAllString(whereClause, "")
85+
86+
// Remove SQL comments
87+
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
88+
89+
return strings.TrimSpace(whereClause)
90+
}

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy