diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a1c758ce03415..8b616f34b8441 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1353,15 +1353,26 @@ func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLog if err == nil { return q.db.CountAuditLogs(ctx, arg) } - prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAuditLog.Type) if err != nil { return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err) } - return q.db.CountAuthorizedAuditLogs(ctx, arg, prep) } +func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) { + // Just like the actual query, shortcut if the user is an owner. + err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog) + if err == nil { + return q.db.CountConnectionLogs(ctx, arg) + } + prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceConnectionLog.Type) + if err != nil { + return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep) +} + func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil { return nil, err @@ -5392,3 +5403,7 @@ func (q *querier) CountAuthorizedAuditLogs(ctx context.Context, arg database.Cou func (q *querier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, _ rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) { return q.GetConnectionLogsOffset(ctx, arg) } + +func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, _ rbac.PreparedAuthorized) (int64, error) { + return q.CountConnectionLogs(ctx, arg) +} diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 5416f33e521ec..2ea27f7d92342 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -406,6 +406,42 @@ func (s *MethodTestSuite) TestConnectionLogs() { LimitOpt: 10, }, emptyPreparedAuthorized{}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead) })) + s.Run("CountConnectionLogs", s.Subtest(func(db database.Store, check *expects) { + ws := createWorkspace(s.T(), db) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + check.Args(database.CountConnectionLogsParams{}).Asserts( + rbac.ResourceConnectionLog, policy.ActionRead, + ).WithNotAuthorized("nil") + })) + s.Run("CountAuthorizedConnectionLogs", s.Subtest(func(db database.Store, check *expects) { + ws := createWorkspace(s.T(), db) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + check.Args(database.CountConnectionLogsParams{}, emptyPreparedAuthorized{}).Asserts( + rbac.ResourceConnectionLog, policy.ActionRead, + ) + })) } func (s *MethodTestSuite) TestFile() { diff --git a/coderd/database/dbauthz/setup_test.go b/coderd/database/dbauthz/setup_test.go index 23effafc632e0..d4dacb78a4d50 100644 --- a/coderd/database/dbauthz/setup_test.go +++ b/coderd/database/dbauthz/setup_test.go @@ -318,7 +318,7 @@ func hasEmptyResponse(values []reflect.Value) bool { } } - // Special case for int64, as it's the return type for count query. + // Special case for int64, as it's the return type for count queries. if r.Kind() == reflect.Int64 { if r.Int() == 0 { return true diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index e353a4688281d..a0090a1103279 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -194,6 +194,13 @@ func (m queryMetricsStore) CountAuditLogs(ctx context.Context, arg database.Coun return r0, r1 } +func (m queryMetricsStore) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) { + start := time.Now() + r0, r1 := m.s.CountConnectionLogs(ctx, arg) + m.queryLatencies.WithLabelValues("CountConnectionLogs").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) { start := time.Now() r0, r1 := m.s.CountInProgressPrebuilds(ctx) @@ -3413,3 +3420,10 @@ func (m queryMetricsStore) GetAuthorizedConnectionLogsOffset(ctx context.Context m.queryLatencies.WithLabelValues("GetAuthorizedConnectionLogsOffset").Observe(time.Since(start).Seconds()) return r0, r1 } + +func (m queryMetricsStore) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) { + start := time.Now() + r0, r1 := m.s.CountAuthorizedConnectionLogs(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("CountAuthorizedConnectionLogs").Observe(time.Since(start).Seconds()) + return r0, r1 +} diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 14e5344325b9b..723c4f3687e81 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -278,6 +278,36 @@ func (mr *MockStoreMockRecorder) CountAuthorizedAuditLogs(ctx, arg, prepared any return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAuditLogs", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAuditLogs), ctx, arg, prepared) } +// CountAuthorizedConnectionLogs mocks base method. +func (m *MockStore) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAuthorizedConnectionLogs", ctx, arg, prepared) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAuthorizedConnectionLogs indicates an expected call of CountAuthorizedConnectionLogs. +func (mr *MockStoreMockRecorder) CountAuthorizedConnectionLogs(ctx, arg, prepared any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedConnectionLogs", reflect.TypeOf((*MockStore)(nil).CountAuthorizedConnectionLogs), ctx, arg, prepared) +} + +// CountConnectionLogs mocks base method. +func (m *MockStore) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountConnectionLogs", ctx, arg) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountConnectionLogs indicates an expected call of CountConnectionLogs. +func (mr *MockStoreMockRecorder) CountConnectionLogs(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountConnectionLogs", reflect.TypeOf((*MockStore)(nil).CountConnectionLogs), ctx, arg) +} + // CountInProgressPrebuilds mocks base method. func (m *MockStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) { m.ctrl.T.Helper() diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 193ac3daa46bf..6bb7483847a2e 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -614,6 +614,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi type connectionLogQuerier interface { GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) + CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) } func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) { @@ -700,6 +701,53 @@ func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg return items, nil } +func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.ConnectionLogConverter(), + }) + if err != nil { + return 0, xerrors.Errorf("compile authorized filter: %w", err) + } + filtered, err := insertAuthorizedFilter(countConnectionLogs, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return 0, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: CountAuthorizedConnectionLogs :one\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.OrganizationID, + arg.WorkspaceOwner, + arg.WorkspaceOwnerID, + arg.WorkspaceOwnerEmail, + arg.Type, + arg.UserID, + arg.Username, + arg.UserEmail, + arg.ConnectedAfter, + arg.ConnectedBefore, + arg.WorkspaceID, + arg.ConnectionID, + arg.Status, + ) + if err != nil { + return 0, err + } + defer rows.Close() + var count int64 + for rows.Next() { + if err := rows.Scan(&count); err != nil { + return 0, err + } + } + if err := rows.Close(); err != nil { + return 0, err + } + if err := rows.Err(); err != nil { + return 0, err + } + return count, nil +} + func insertAuthorizedFilter(query string, replaceWith string) (string, error) { if !strings.Contains(query, authorizedQueryPlaceholder) { return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query") diff --git a/coderd/database/modelqueries_internal_test.go b/coderd/database/modelqueries_internal_test.go index 4f675a1b60785..275ed947a3e4c 100644 --- a/coderd/database/modelqueries_internal_test.go +++ b/coderd/database/modelqueries_internal_test.go @@ -76,6 +76,19 @@ func TestAuditLogsQueryConsistency(t *testing.T) { } } +// Same as TestAuditLogsQueryConsistency, but for connection logs. +func TestConnectionLogsQueryConsistency(t *testing.T) { + t.Parallel() + + getWhereClause := extractWhereClause(getConnectionLogsOffset) + require.NotEmpty(t, getWhereClause, "getConnectionLogsOffset query should have a WHERE clause") + + countWhereClause := extractWhereClause(countConnectionLogs) + require.NotEmpty(t, countWhereClause, "countConnectionLogs query should have a WHERE clause") + + require.Equal(t, getWhereClause, countWhereClause, "getConnectionLogsOffset and countConnectionLogs queries should have the same WHERE clause") +} + // extractWhereClause extracts the WHERE clause from a SQL query string func extractWhereClause(query string) string { // Find WHERE and get everything after it diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 8af37596cb5c6..72f511618838b 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -66,6 +66,7 @@ type sqlcQuerier interface { CleanTailnetLostPeers(ctx context.Context) error CleanTailnetTunnels(ctx context.Context) error CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) + CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) // CountInProgressPrebuilds returns the number of in-progress prebuilds, grouped by preset ID and transition. // Prebuild considered in-progress if it's in the "starting", "stopping", or "deleting" state. CountInProgressPrebuilds(ctx context.Context) ([]CountInProgressPrebuildsRow, error) diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index a3d48e46b4fe7..20b07450364af 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -2168,6 +2168,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) { require.NoError(t, err) // Then: No logs returned require.Len(t, logs, 0, "no logs should be returned") + // And: The count matches the number of logs returned + count, err := authDb.CountConnectionLogs(memberCtx, database.CountConnectionLogsParams{}) + require.NoError(t, err) + require.EqualValues(t, len(logs), count) }) t.Run("SiteWideAuditor", func(t *testing.T) { @@ -2186,6 +2190,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) { require.NoError(t, err) // Then: All logs are returned require.ElementsMatch(t, connectionOnlyIDs(allLogs), connectionOnlyIDs(logs)) + // And: The count matches the number of logs returned + count, err := authDb.CountConnectionLogs(siteAuditorCtx, database.CountConnectionLogsParams{}) + require.NoError(t, err) + require.EqualValues(t, len(logs), count) }) t.Run("SingleOrgAuditor", func(t *testing.T) { @@ -2205,6 +2213,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) { require.NoError(t, err) // Then: Only the logs for the organization are returned require.ElementsMatch(t, orgConnectionLogs[orgID], connectionOnlyIDs(logs)) + // And: The count matches the number of logs returned + count, err := authDb.CountConnectionLogs(orgAuditCtx, database.CountConnectionLogsParams{}) + require.NoError(t, err) + require.EqualValues(t, len(logs), count) }) t.Run("TwoOrgAuditors", func(t *testing.T) { @@ -2225,6 +2237,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) { require.NoError(t, err) // Then: All logs for both organizations are returned require.ElementsMatch(t, append(orgConnectionLogs[first], orgConnectionLogs[second]...), connectionOnlyIDs(logs)) + // And: The count matches the number of logs returned + count, err := authDb.CountConnectionLogs(multiOrgAuditCtx, database.CountConnectionLogsParams{}) + require.NoError(t, err) + require.EqualValues(t, len(logs), count) }) t.Run("ErroneousOrg", func(t *testing.T) { @@ -2243,9 +2259,71 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) { require.NoError(t, err) // Then: No logs are returned require.Len(t, logs, 0, "no logs should be returned") + // And: The count matches the number of logs returned + count, err := authDb.CountConnectionLogs(userCtx, database.CountConnectionLogsParams{}) + require.NoError(t, err) + require.EqualValues(t, len(logs), count) }) } +func TestCountConnectionLogs(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + db, _ := dbtestutil.NewDB(t) + + orgA := dbfake.Organization(t, db).Do() + userA := dbgen.User(t, db, database.User{}) + tplA := dbgen.Template(t, db, database.Template{OrganizationID: orgA.Org.ID, CreatedBy: userA.ID}) + wsA := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userA.ID, OrganizationID: orgA.Org.ID, TemplateID: tplA.ID}) + + orgB := dbfake.Organization(t, db).Do() + userB := dbgen.User(t, db, database.User{}) + tplB := dbgen.Template(t, db, database.Template{OrganizationID: orgB.Org.ID, CreatedBy: userB.ID}) + wsB := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userB.ID, OrganizationID: orgB.Org.ID, TemplateID: tplB.ID}) + + // Create logs for two different orgs. + for i := 0; i < 20; i++ { + dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{ + OrganizationID: wsA.OrganizationID, + WorkspaceOwnerID: wsA.OwnerID, + WorkspaceID: wsA.ID, + Type: database.ConnectionTypeSsh, + }) + } + for i := 0; i < 10; i++ { + dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{ + OrganizationID: wsB.OrganizationID, + WorkspaceOwnerID: wsB.OwnerID, + WorkspaceID: wsB.ID, + Type: database.ConnectionTypeSsh, + }) + } + + // Count with a filter for orgA. + countParams := database.CountConnectionLogsParams{ + OrganizationID: orgA.Org.ID, + } + totalCount, err := db.CountConnectionLogs(ctx, countParams) + require.NoError(t, err) + require.Equal(t, int64(20), totalCount) + + // Get a paginated result for the same filter. + getParams := database.GetConnectionLogsOffsetParams{ + OrganizationID: orgA.Org.ID, + LimitOpt: 5, + OffsetOpt: 10, + } + logs, err := db.GetConnectionLogsOffset(ctx, getParams) + require.NoError(t, err) + require.Len(t, logs, 5) + + // The count with the filter should remain the same, independent of pagination. + countAfterGet, err := db.CountConnectionLogs(ctx, countParams) + require.NoError(t, err) + require.Equal(t, int64(20), countAfterGet) +} + func TestConnectionLogsOffsetFilters(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) @@ -2484,7 +2562,24 @@ func TestConnectionLogsOffsetFilters(t *testing.T) { t.Parallel() logs, err := db.GetConnectionLogsOffset(ctx, tc.params) require.NoError(t, err) + count, err := db.CountConnectionLogs(ctx, database.CountConnectionLogsParams{ + OrganizationID: tc.params.OrganizationID, + WorkspaceOwner: tc.params.WorkspaceOwner, + Type: tc.params.Type, + UserID: tc.params.UserID, + Username: tc.params.Username, + UserEmail: tc.params.UserEmail, + ConnectedAfter: tc.params.ConnectedAfter, + ConnectedBefore: tc.params.ConnectedBefore, + WorkspaceID: tc.params.WorkspaceID, + ConnectionID: tc.params.ConnectionID, + Status: tc.params.Status, + WorkspaceOwnerID: tc.params.WorkspaceOwnerID, + WorkspaceOwnerEmail: tc.params.WorkspaceOwnerEmail, + }) + require.NoError(t, err) require.ElementsMatch(t, tc.expectedLogIDs, connectionOnlyIDs(logs)) + require.Equal(t, len(tc.expectedLogIDs), int(count), "CountConnectionLogs should match the number of returned logs (no offset or limit)") }) } } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index cef983eb0f1b9..676ce75621ded 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -880,6 +880,150 @@ func (q *sqlQuerier) InsertAuditLog(ctx context.Context, arg InsertAuditLogParam return i, err } +const countConnectionLogs = `-- name: CountConnectionLogs :one +SELECT + COUNT(*) AS count +FROM + connection_logs +JOIN users AS workspace_owner ON + connection_logs.workspace_owner_id = workspace_owner.id +LEFT JOIN users ON + connection_logs.user_id = users.id +JOIN organizations ON + connection_logs.organization_id = organizations.id +WHERE + -- Filter organization_id + CASE + WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.organization_id = $1 + ELSE true + END + -- Filter by workspace owner username + AND CASE + WHEN $2 :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE lower(username) = lower($2) AND deleted = false + ) + ELSE true + END + -- Filter by workspace_owner_id + AND CASE + WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + workspace_owner_id = $3 + ELSE true + END + -- Filter by workspace_owner_email + AND CASE + WHEN $4 :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE email = $4 AND deleted = false + ) + ELSE true + END + -- Filter by type + AND CASE + WHEN $5 :: text != '' THEN + type = $5 :: connection_type + ELSE true + END + -- Filter by user_id + AND CASE + WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_id = $6 + ELSE true + END + -- Filter by username + AND CASE + WHEN $7 :: text != '' THEN + user_id = ( + SELECT id FROM users + WHERE lower(username) = lower($7) AND deleted = false + ) + ELSE true + END + -- Filter by user_email + AND CASE + WHEN $8 :: text != '' THEN + users.email = $8 + ELSE true + END + -- Filter by connected_after + AND CASE + WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time >= $9 + ELSE true + END + -- Filter by connected_before + AND CASE + WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time <= $10 + ELSE true + END + -- Filter by workspace_id + AND CASE + WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.workspace_id = $11 + ELSE true + END + -- Filter by connection_id + AND CASE + WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.connection_id = $12 + ELSE true + END + -- Filter by whether the session has a disconnect_time + AND CASE + WHEN $13 :: text != '' THEN + (($13 = 'ongoing' AND disconnect_time IS NULL) OR + ($13 = 'completed' AND disconnect_time IS NOT NULL)) AND + -- Exclude web events, since we don't know their close time. + "type" NOT IN ('workspace_app', 'port_forwarding') + ELSE true + END + -- Authorize Filter clause will be injected below in + -- CountAuthorizedConnectionLogs + -- @authorize_filter +` + +type CountConnectionLogsParams struct { + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + WorkspaceOwner string `db:"workspace_owner" json:"workspace_owner"` + WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"` + WorkspaceOwnerEmail string `db:"workspace_owner_email" json:"workspace_owner_email"` + Type string `db:"type" json:"type"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + Username string `db:"username" json:"username"` + UserEmail string `db:"user_email" json:"user_email"` + ConnectedAfter time.Time `db:"connected_after" json:"connected_after"` + ConnectedBefore time.Time `db:"connected_before" json:"connected_before"` + WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"` + ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"` + Status string `db:"status" json:"status"` +} + +func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countConnectionLogs, + arg.OrganizationID, + arg.WorkspaceOwner, + arg.WorkspaceOwnerID, + arg.WorkspaceOwnerEmail, + arg.Type, + arg.UserID, + arg.Username, + arg.UserEmail, + arg.ConnectedAfter, + arg.ConnectedBefore, + arg.WorkspaceID, + arg.ConnectionID, + arg.Status, + ) + var count int64 + err := row.Scan(&count) + return count, err +} + const getConnectionLogsOffset = `-- name: GetConnectionLogsOffset :many SELECT connection_logs.id, connection_logs.connect_time, connection_logs.organization_id, connection_logs.workspace_owner_id, connection_logs.workspace_id, connection_logs.workspace_name, connection_logs.agent_name, connection_logs.type, connection_logs.ip, connection_logs.code, connection_logs.user_agent, connection_logs.user_id, connection_logs.slug_or_port, connection_logs.connection_id, connection_logs.disconnect_time, connection_logs.disconnect_reason, diff --git a/coderd/database/queries/connectionlogs.sql b/coderd/database/queries/connectionlogs.sql index e3f231a6b738e..eb2d1b0cb171a 100644 --- a/coderd/database/queries/connectionlogs.sql +++ b/coderd/database/queries/connectionlogs.sql @@ -132,6 +132,112 @@ LIMIT OFFSET @offset_opt; +-- name: CountConnectionLogs :one +SELECT + COUNT(*) AS count +FROM + connection_logs +JOIN users AS workspace_owner ON + connection_logs.workspace_owner_id = workspace_owner.id +LEFT JOIN users ON + connection_logs.user_id = users.id +JOIN organizations ON + connection_logs.organization_id = organizations.id +WHERE + -- Filter organization_id + CASE + WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.organization_id = @organization_id + ELSE true + END + -- Filter by workspace owner username + AND CASE + WHEN @workspace_owner :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE lower(username) = lower(@workspace_owner) AND deleted = false + ) + ELSE true + END + -- Filter by workspace_owner_id + AND CASE + WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + workspace_owner_id = @workspace_owner_id + ELSE true + END + -- Filter by workspace_owner_email + AND CASE + WHEN @workspace_owner_email :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE email = @workspace_owner_email AND deleted = false + ) + ELSE true + END + -- Filter by type + AND CASE + WHEN @type :: text != '' THEN + type = @type :: connection_type + ELSE true + END + -- Filter by user_id + AND CASE + WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_id = @user_id + ELSE true + END + -- Filter by username + AND CASE + WHEN @username :: text != '' THEN + user_id = ( + SELECT id FROM users + WHERE lower(username) = lower(@username) AND deleted = false + ) + ELSE true + END + -- Filter by user_email + AND CASE + WHEN @user_email :: text != '' THEN + users.email = @user_email + ELSE true + END + -- Filter by connected_after + AND CASE + WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time >= @connected_after + ELSE true + END + -- Filter by connected_before + AND CASE + WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time <= @connected_before + ELSE true + END + -- Filter by workspace_id + AND CASE + WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.workspace_id = @workspace_id + ELSE true + END + -- Filter by connection_id + AND CASE + WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.connection_id = @connection_id + ELSE true + END + -- Filter by whether the session has a disconnect_time + AND CASE + WHEN @status :: text != '' THEN + ((@status = 'ongoing' AND disconnect_time IS NULL) OR + (@status = 'completed' AND disconnect_time IS NOT NULL)) AND + -- Exclude web events, since we don't know their close time. + "type" NOT IN ('workspace_app', 'port_forwarding') + ELSE true + END + -- Authorize Filter clause will be injected below in + -- CountAuthorizedConnectionLogs + -- @authorize_filter +; -- name: UpsertConnectionLog :one INSERT INTO connection_logs ( diff --git a/coderd/searchquery/search.go b/coderd/searchquery/search.go index c17b3db77bdc5..d35f3c94b5ff7 100644 --- a/coderd/searchquery/search.go +++ b/coderd/searchquery/search.go @@ -86,7 +86,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G return filter, countFilter, parser.Errors } -func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey database.APIKey) (database.GetConnectionLogsOffsetParams, []codersdk.ValidationError) { +func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey database.APIKey) (database.GetConnectionLogsOffsetParams, database.CountConnectionLogsParams, []codersdk.ValidationError) { // Always lowercase for all searches. query = strings.ToLower(query) values, errors := searchTerms(query, func(term string, values url.Values) error { @@ -94,7 +94,8 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey return nil }) if len(errors) > 0 { - return database.GetConnectionLogsOffsetParams{}, errors + // nolint:exhaustruct // We don't need to initialize these structs because we return an error. + return database.GetConnectionLogsOffsetParams{}, database.CountConnectionLogsParams{}, errors } parser := httpapi.NewQueryParamParser() @@ -122,8 +123,24 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey filter.WorkspaceOwner = "" } + // This MUST be kept in sync with the above + countFilter := database.CountConnectionLogsParams{ + OrganizationID: filter.OrganizationID, + WorkspaceOwner: filter.WorkspaceOwner, + WorkspaceOwnerID: filter.WorkspaceOwnerID, + WorkspaceOwnerEmail: filter.WorkspaceOwnerEmail, + Type: filter.Type, + UserID: filter.UserID, + Username: filter.Username, + UserEmail: filter.UserEmail, + ConnectedAfter: filter.ConnectedAfter, + ConnectedBefore: filter.ConnectedBefore, + WorkspaceID: filter.WorkspaceID, + ConnectionID: filter.ConnectionID, + Status: filter.Status, + } parser.ErrorExcessParams(values) - return filter, parser.Errors + return filter, countFilter, parser.Errors } func Users(query string) (database.GetUsersParams, []codersdk.ValidationError) { diff --git a/coderd/searchquery/search_test.go b/coderd/searchquery/search_test.go index c251a4cd5bd90..4744b57edff4a 100644 --- a/coderd/searchquery/search_test.go +++ b/coderd/searchquery/search_test.go @@ -435,7 +435,7 @@ func TestSearchConnectionLogs(t *testing.T) { `connected_before:"2023-01-16T12:00:00+12:00" workspace_id:%s connection_id:%s status:ongoing`, workspaceID.String(), connectionID.String()) - values, errs := searchquery.ConnectionLogs(context.Background(), db, query, database.APIKey{}) + values, _, errs := searchquery.ConnectionLogs(context.Background(), db, query, database.APIKey{}) require.Len(t, errs, 0) expected := database.GetConnectionLogsOffsetParams{ @@ -462,7 +462,7 @@ func TestSearchConnectionLogs(t *testing.T) { db, _ := dbtestutil.NewDB(t) query := `username:me workspace_owner:me` - values, errs := searchquery.ConnectionLogs(context.Background(), db, query, database.APIKey{UserID: userID}) + values, _, errs := searchquery.ConnectionLogs(context.Background(), db, query, database.APIKey{UserID: userID}) require.Len(t, errs, 0) expected := database.GetConnectionLogsOffsetParams{ diff --git a/enterprise/coderd/connectionlog.go b/enterprise/coderd/connectionlog.go index 75413b82708fb..21f0420f0652d 100644 --- a/enterprise/coderd/connectionlog.go +++ b/enterprise/coderd/connectionlog.go @@ -36,7 +36,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) { } queryStr := r.URL.Query().Get("q") - filter, errs := searchquery.ConnectionLogs(ctx, api.Database, queryStr, apiKey) + filter, countFilter, errs := searchquery.ConnectionLogs(ctx, api.Database, queryStr, apiKey) if len(errs) > 0 { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid connection search query.", @@ -49,6 +49,24 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) { // #nosec G115 - Safe conversion as pagination limit is expected to be within int32 range filter.LimitOpt = int32(page.Limit) + count, err := api.Database.CountConnectionLogs(ctx, countFilter) + if dbauthz.IsNotAuthorizedError(err) { + httpapi.Forbidden(rw) + return + } + if err != nil { + httpapi.InternalServerError(rw, err) + return + } + + if count == 0 { + httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{ + ConnectionLogs: []codersdk.ConnectionLog{}, + Count: 0, + }) + return + } + dblogs, err := api.Database.GetConnectionLogsOffset(ctx, filter) if dbauthz.IsNotAuthorizedError(err) { httpapi.Forbidden(rw) @@ -61,7 +79,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{ ConnectionLogs: convertConnectionLogs(dblogs), - Count: 0, // TODO(ethanndickson): Set count + Count: count, }) } diff --git a/enterprise/coderd/connectionlog_test.go b/enterprise/coderd/connectionlog_test.go index b94b2449f37c4..59ff1b780e7b6 100644 --- a/enterprise/coderd/connectionlog_test.go +++ b/enterprise/coderd/connectionlog_test.go @@ -65,6 +65,7 @@ func TestConnectionLogs(t *testing.T) { require.NoError(t, err) require.Len(t, logs.ConnectionLogs, 1) + require.EqualValues(t, 1, logs.Count) require.Equal(t, codersdk.ConnectionTypeSSH, logs.ConnectionLogs[0].Type) }) @@ -84,7 +85,7 @@ func TestConnectionLogs(t *testing.T) { logs, err := client.ConnectionLogs(ctx, codersdk.ConnectionLogsRequest{}) require.NoError(t, err) - + require.EqualValues(t, 0, logs.Count) require.Len(t, logs.ConnectionLogs, 0) }) @@ -133,6 +134,7 @@ func TestConnectionLogs(t *testing.T) { require.NoError(t, err) require.Len(t, logs.ConnectionLogs, 1) + require.EqualValues(t, 1, logs.Count) require.Equal(t, ws.OrganizationID, logs.ConnectionLogs[0].Organization.ID) }) @@ -169,6 +171,7 @@ func TestConnectionLogs(t *testing.T) { require.NoError(t, err) require.Len(t, logs.ConnectionLogs, 1) + require.EqualValues(t, 1, logs.Count) require.NotNil(t, logs.ConnectionLogs[0].WebInfo) require.Equal(t, clog.SlugOrPort.String, logs.ConnectionLogs[0].WebInfo.SlugOrPort) require.Equal(t, clog.UserAgent.String, logs.ConnectionLogs[0].WebInfo.UserAgent) @@ -241,6 +244,7 @@ func TestConnectionLogs(t *testing.T) { require.NoError(t, err) require.Len(t, logs.ConnectionLogs, 1) + require.EqualValues(t, 1, logs.Count) require.NotNil(t, logs.ConnectionLogs[0].SSHInfo) require.Nil(t, logs.ConnectionLogs[0].WebInfo) require.Equal(t, codersdk.ConnectionTypeSSH, logs.ConnectionLogs[0].Type)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: