Skip to content

Commit 332f8da

Browse files
committed
chore: populate connectionlog count using a separate query
1 parent cd30512 commit 332f8da

File tree

16 files changed

+689
-11
lines changed

16 files changed

+689
-11
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,21 @@ func (q *querier) CleanTailnetTunnels(ctx context.Context) error {
13231323
return q.db.CleanTailnetTunnels(ctx)
13241324
}
13251325

1326+
func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
1327+
// Just like the actual query, shortcut if the user is an owner.
1328+
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
1329+
if err == nil {
1330+
return q.db.CountConnectionLogs(ctx, arg)
1331+
}
1332+
1333+
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceConnectionLog.Type)
1334+
if err != nil {
1335+
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
1336+
}
1337+
1338+
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
1339+
}
1340+
13261341
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
13271342
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
13281343
return nil, err
@@ -5301,3 +5316,7 @@ func (q *querier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg database
53015316
func (q *querier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, _ rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
53025317
return q.GetConnectionLogsOffset(ctx, arg)
53035318
}
5319+
5320+
func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, _ rbac.PreparedAuthorized) (int64, error) {
5321+
return q.CountConnectionLogs(ctx, arg)
5322+
}

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,42 @@ func (s *MethodTestSuite) TestConnectionLogs() {
391391
LimitOpt: 10,
392392
}, emptyPreparedAuthorized{}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead)
393393
}))
394+
s.Run("CountConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
395+
ws := createWorkspace(s.T(), db)
396+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
397+
Type: database.ConnectionTypeSsh,
398+
WorkspaceID: ws.ID,
399+
OrganizationID: ws.OrganizationID,
400+
WorkspaceOwnerID: ws.OwnerID,
401+
})
402+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
403+
Type: database.ConnectionTypeSsh,
404+
WorkspaceID: ws.ID,
405+
OrganizationID: ws.OrganizationID,
406+
WorkspaceOwnerID: ws.OwnerID,
407+
})
408+
check.Args(database.CountConnectionLogsParams{}).Asserts(
409+
rbac.ResourceConnectionLog, policy.ActionRead,
410+
).WithNotAuthorized("nil")
411+
}))
412+
s.Run("CountAuthorizedConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
413+
ws := createWorkspace(s.T(), db)
414+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
415+
Type: database.ConnectionTypeSsh,
416+
WorkspaceID: ws.ID,
417+
OrganizationID: ws.OrganizationID,
418+
WorkspaceOwnerID: ws.OwnerID,
419+
})
420+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
421+
Type: database.ConnectionTypeSsh,
422+
WorkspaceID: ws.ID,
423+
OrganizationID: ws.OrganizationID,
424+
WorkspaceOwnerID: ws.OwnerID,
425+
})
426+
check.Args(database.CountConnectionLogsParams{}, emptyPreparedAuthorized{}).Asserts(
427+
rbac.ResourceConnectionLog, policy.ActionRead,
428+
)
429+
}))
394430
}
395431

396432
func (s *MethodTestSuite) TestFile() {

coderd/database/dbauthz/setup_test.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
271271

272272
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
273273
// any case where the error is nil and the response is an empty slice.
274-
if err != nil || !hasEmptySliceResponse(resp) {
274+
if err != nil || !hasEmptyResponse(resp) {
275275
// Expect the default error
276276
if testCase.notAuthorizedExpect == "" {
277277
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
@@ -297,7 +297,7 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
297297

298298
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
299299
// any case where the error is nil and the response is an empty slice.
300-
if err != nil || !hasEmptySliceResponse(resp) {
300+
if err != nil || !hasEmptyResponse(resp) {
301301
if testCase.cancelledCtxExpect == "" {
302302
s.Errorf(err, "method should an error with cancellation")
303303
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
@@ -308,13 +308,20 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
308308
})
309309
}
310310

311-
func hasEmptySliceResponse(values []reflect.Value) bool {
311+
func hasEmptyResponse(values []reflect.Value) bool {
312312
for _, r := range values {
313313
if r.Kind() == reflect.Slice || r.Kind() == reflect.Array {
314314
if r.Len() == 0 {
315315
return true
316316
}
317317
}
318+
319+
// Special case for int64, as it's the return type for count queries.
320+
if r.Kind() == reflect.Int64 {
321+
if r.Int() == 0 {
322+
return true
323+
}
324+
}
318325
}
319326
return false
320327
}

coderd/database/dbmem/dbmem.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,6 +1780,10 @@ func (*FakeQuerier) CleanTailnetTunnels(context.Context) error {
17801780
return ErrUnimplemented
17811781
}
17821782

1783+
func (q *FakeQuerier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
1784+
return q.CountAuthorizedConnectionLogs(ctx, arg, nil)
1785+
}
1786+
17831787
func (q *FakeQuerier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
17841788
return nil, ErrUnimplemented
17851789
}
@@ -14170,3 +14174,100 @@ func (q *FakeQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg
1417014174

1417114175
return logs, nil
1417214176
}
14177+
14178+
func (q *FakeQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
14179+
if err := validateDatabaseType(arg); err != nil {
14180+
return 0, err
14181+
}
14182+
14183+
// Call this to match the same function calls as the SQL implementation.
14184+
// It functionally does nothing for filtering.
14185+
if prepared != nil {
14186+
_, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
14187+
VariableConverter: regosql.ConnectionLogConverter(),
14188+
})
14189+
if err != nil {
14190+
return 0, err
14191+
}
14192+
}
14193+
14194+
q.mutex.RLock()
14195+
defer q.mutex.RUnlock()
14196+
14197+
var count int64
14198+
14199+
for _, clog := range q.connectionLogs {
14200+
if arg.OrganizationID != uuid.Nil && clog.OrganizationID != arg.OrganizationID {
14201+
continue
14202+
}
14203+
if arg.WorkspaceOwner != "" {
14204+
workspaceOwner, err := q.getUserByIDNoLock(clog.WorkspaceOwnerID)
14205+
if err == nil && !strings.EqualFold(arg.WorkspaceOwner, workspaceOwner.Username) {
14206+
continue
14207+
}
14208+
}
14209+
if arg.WorkspaceOwnerID != uuid.Nil && clog.WorkspaceOwnerID != arg.WorkspaceOwnerID {
14210+
continue
14211+
}
14212+
if arg.WorkspaceOwnerEmail != "" {
14213+
workspaceOwner, err := q.getUserByIDNoLock(clog.WorkspaceOwnerID)
14214+
if err != nil || workspaceOwner.Email != arg.WorkspaceOwnerEmail {
14215+
continue
14216+
}
14217+
}
14218+
if arg.Type != "" && string(clog.Type) != arg.Type {
14219+
continue
14220+
}
14221+
if arg.UserID != uuid.Nil && (!clog.UserID.Valid || clog.UserID.UUID != arg.UserID) {
14222+
continue
14223+
}
14224+
if arg.Username != "" {
14225+
if !clog.UserID.Valid {
14226+
continue
14227+
}
14228+
user, err := q.getUserByIDNoLock(clog.UserID.UUID)
14229+
if err != nil || user.Username != arg.Username {
14230+
continue
14231+
}
14232+
}
14233+
if arg.UserEmail != "" {
14234+
if !clog.UserID.Valid {
14235+
continue
14236+
}
14237+
user, err := q.getUserByIDNoLock(clog.UserID.UUID)
14238+
if err != nil || user.Email != arg.UserEmail {
14239+
continue
14240+
}
14241+
}
14242+
if !arg.StartedAfter.IsZero() && clog.Time.Before(arg.StartedAfter) {
14243+
continue
14244+
}
14245+
if !arg.StartedBefore.IsZero() && clog.Time.After(arg.StartedBefore) {
14246+
continue
14247+
}
14248+
if arg.WorkspaceID != uuid.Nil && clog.WorkspaceID != arg.WorkspaceID {
14249+
continue
14250+
}
14251+
if arg.ConnectionID != uuid.Nil && (!clog.ConnectionID.Valid || clog.ConnectionID.UUID != arg.ConnectionID) {
14252+
continue
14253+
}
14254+
if arg.Status != "" {
14255+
if clog.Type == database.ConnectionTypeWorkspaceApp ||
14256+
clog.Type == database.ConnectionTypePortForwarding {
14257+
continue
14258+
}
14259+
isConnected := !clog.CloseTime.Valid
14260+
if (arg.Status == string(codersdk.ConnectionLogStatusOngoing) && !isConnected) || (arg.Status == string(codersdk.ConnectionLogStatusCompleted) && isConnected) {
14261+
continue
14262+
}
14263+
}
14264+
14265+
if prepared != nil && prepared.Authorize(ctx, clog.RBACObject()) != nil {
14266+
continue
14267+
}
14268+
14269+
count++
14270+
}
14271+
14272+
return count, nil
14273+
}

coderd/database/dbmetrics/querymetrics.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/modelqueries.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ func (q *sqlQuerier) GetAuthorizedAuditLogsOffset(ctx context.Context, arg GetAu
566566

567567
type connectionLogQuerier interface {
568568
GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error)
569+
CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error)
569570
}
570571

571572
func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) {
@@ -653,6 +654,53 @@ func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg
653654
return items, nil
654655
}
655656

657+
func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
658+
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
659+
VariableConverter: regosql.ConnectionLogConverter(),
660+
})
661+
if err != nil {
662+
return 0, xerrors.Errorf("compile authorized filter: %w", err)
663+
}
664+
filtered, err := insertAuthorizedFilter(countConnectionLogs, fmt.Sprintf(" AND %s", authorizedFilter))
665+
if err != nil {
666+
return 0, xerrors.Errorf("insert authorized filter: %w", err)
667+
}
668+
669+
query := fmt.Sprintf("-- name: CountAuthorizedConnectionLogs :one\n%s", filtered)
670+
rows, err := q.db.QueryContext(ctx, query,
671+
arg.OrganizationID,
672+
arg.WorkspaceOwner,
673+
arg.WorkspaceOwnerID,
674+
arg.WorkspaceOwnerEmail,
675+
arg.Type,
676+
arg.UserID,
677+
arg.Username,
678+
arg.UserEmail,
679+
arg.StartedAfter,
680+
arg.StartedBefore,
681+
arg.WorkspaceID,
682+
arg.ConnectionID,
683+
arg.Status,
684+
)
685+
if err != nil {
686+
return 0, err
687+
}
688+
defer rows.Close()
689+
var count int64
690+
for rows.Next() {
691+
if err := rows.Scan(&count); err != nil {
692+
return 0, err
693+
}
694+
}
695+
if err := rows.Close(); err != nil {
696+
return 0, err
697+
}
698+
if err := rows.Err(); err != nil {
699+
return 0, err
700+
}
701+
return count, nil
702+
}
703+
656704
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
657705
if !strings.Contains(query, authorizedQueryPlaceholder) {
658706
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")

coderd/database/modelqueries_internal_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package database
22

33
import (
4+
"regexp"
5+
"strings"
46
"testing"
57
"time"
68

@@ -54,3 +56,35 @@ func TestWorkspaceTableConvert(t *testing.T) {
5456
"'workspace.WorkspaceTable()' is not missing at least 1 field when converting to 'WorkspaceTable'. "+
5557
"To resolve this, go to the 'func (w Workspace) WorkspaceTable()' and ensure all fields are converted.")
5658
}
59+
60+
func TestConnectionLogsQueryConsistency(t *testing.T) {
61+
t.Parallel()
62+
63+
getWhereClause := extractWhereClause(getConnectionLogsOffset)
64+
require.NotEmpty(t, getWhereClause, "getConnectionLogsOffset query should have a WHERE clause")
65+
66+
countWhereClause := extractWhereClause(countConnectionLogs)
67+
require.NotEmpty(t, countWhereClause, "countConnectionLogs query should have a WHERE clause")
68+
69+
require.Equal(t, getWhereClause, countWhereClause, "getConnectionLogsOffset and countConnectionLogs queries should have the same WHERE clause")
70+
}
71+
72+
// extractWhereClause extracts the WHERE clause from a SQL query string
73+
func extractWhereClause(query string) string {
74+
// Find WHERE and get everything after it
75+
wherePattern := regexp.MustCompile(`(?is)WHERE\s+(.*)`)
76+
whereMatches := wherePattern.FindStringSubmatch(query)
77+
if len(whereMatches) < 2 {
78+
return ""
79+
}
80+
81+
whereClause := whereMatches[1]
82+
83+
// Remove ORDER BY, LIMIT, OFFSET clauses from the end
84+
whereClause = regexp.MustCompile(`(?is)\s+(ORDER BY|LIMIT|OFFSET).*$`).ReplaceAllString(whereClause, "")
85+
86+
// Remove SQL comments
87+
whereClause = regexp.MustCompile(`(?m)--.*$`).ReplaceAllString(whereClause, "")
88+
89+
return strings.TrimSpace(whereClause)
90+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy