diff --git a/.github/workflows/backend-ci-cd.yml b/.github/workflows/backend-ci-cd.yml index a313152..178bb93 100644 --- a/.github/workflows/backend-ci-cd.yml +++ b/.github/workflows/backend-ci-cd.yml @@ -70,6 +70,16 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + with: + lfs: true + + - name: Cache Git LFS + uses: actions/cache@v4 + with: + path: .git/lfs + key: ${{ runner.os }}-git-lfs-${{ hashFiles('.gitattributes') }} + restore-keys: | + ${{ runner.os }}-git-lfs- - name: Get version id: get_version diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 70b87e2..c8fbf60 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -40,10 +40,6 @@ func main() { panic(err) } - if err := s.modelV1.InitModel(); err != nil { - panic(err) - } - svc := service.NewService(service.WithPprof()) svc.Add(s) if err := svc.Run(); err != nil { diff --git a/backend/config/config.go b/backend/config/config.go index 9022a05..fec89bf 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -63,8 +63,9 @@ type Config struct { } `mapstructure:"init_model"` Extension struct { - Baseurl string `mapstructure:"baseurl"` - Limit int `mapstructure:"limit"` + Baseurl string `mapstructure:"baseurl"` + LimitSecond int `mapstructure:"limit_second"` + Limit int `mapstructure:"limit"` } `mapstructure:"extension"` } @@ -96,11 +97,12 @@ func Init() (*Config, error) { v.SetDefault("llm_proxy.client_pool_size", 100) v.SetDefault("llm_proxy.stream_client_pool_size", 5000) v.SetDefault("llm_proxy.request_log_path", "/app/request/logs") - v.SetDefault("init_model.name", "qwen2.5-coder-3b-instruct") + v.SetDefault("init_model.name", "") v.SetDefault("init_model.key", "") - v.SetDefault("init_model.url", "https://model-square.app.baizhi.cloud/v1") + v.SetDefault("init_model.url", "") v.SetDefault("extension.baseurl", "https://release.baizhi.cloud") - v.SetDefault("extension.limit", 10) + v.SetDefault("extension.limit", 1) + v.SetDefault("extension.limit_second", 10) c := Config{} if err := v.Unmarshal(&c); err != nil { diff --git a/backend/consts/openai.go b/backend/consts/openai.go index 68a1815..935f147 100644 --- a/backend/consts/openai.go +++ b/backend/consts/openai.go @@ -12,4 +12,5 @@ type ChatRole string const ( ChatRoleUser ChatRole = "user" ChatRoleAssistant ChatRole = "assistant" + ChatRoleSystem ChatRole = "system" ) diff --git a/backend/consts/proxy.go b/backend/consts/proxy.go new file mode 100644 index 0000000..f4c72eb --- /dev/null +++ b/backend/consts/proxy.go @@ -0,0 +1,9 @@ +package consts + +type ReportAction string + +const ( + ReportActionAccept ReportAction = "accept" + ReportActionSuggest ReportAction = "suggest" + ReportActionFileWritten ReportAction = "file_written" +) diff --git a/backend/db/migrate/schema.go b/backend/db/migrate/schema.go index 35f1d1b..0f58067 100644 --- a/backend/db/migrate/schema.go +++ b/backend/db/migrate/schema.go @@ -270,6 +270,7 @@ var ( {Name: "code_lines", Type: field.TypeInt64, Nullable: true}, {Name: "input_tokens", Type: field.TypeInt64, Nullable: true}, {Name: "output_tokens", Type: field.TypeInt64, Nullable: true}, + {Name: "is_suggested", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime}, {Name: "updated_at", Type: field.TypeTime}, {Name: "model_id", Type: field.TypeUUID, Nullable: true}, @@ -283,13 +284,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "tasks_models_tasks", - Columns: []*schema.Column{TasksColumns[13]}, + Columns: []*schema.Column{TasksColumns[14]}, RefColumns: []*schema.Column{ModelsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "tasks_users_tasks", - Columns: []*schema.Column{TasksColumns[14]}, + Columns: []*schema.Column{TasksColumns[15]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.SetNull, }, @@ -302,6 +303,8 @@ var ( {Name: "role", Type: field.TypeString}, {Name: "completion", Type: field.TypeString}, {Name: "output_tokens", Type: field.TypeInt64}, + {Name: "code_lines", Type: field.TypeInt64}, + {Name: "code", Type: field.TypeString, Nullable: true}, {Name: "created_at", Type: field.TypeTime}, {Name: "updated_at", Type: field.TypeTime}, {Name: "task_id", Type: field.TypeUUID, Nullable: true}, @@ -314,7 +317,7 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "task_records_tasks_task_records", - Columns: []*schema.Column{TaskRecordsColumns[7]}, + Columns: []*schema.Column{TaskRecordsColumns[9]}, RefColumns: []*schema.Column{TasksColumns[0]}, OnDelete: schema.SetNull, }, diff --git a/backend/db/mutation.go b/backend/db/mutation.go index bc77a7e..9cfdc87 100644 --- a/backend/db/mutation.go +++ b/backend/db/mutation.go @@ -9876,6 +9876,7 @@ type TaskMutation struct { addinput_tokens *int64 output_tokens *int64 addoutput_tokens *int64 + is_suggested *bool created_at *time.Time updated_at *time.Time clearedFields map[string]struct{} @@ -10607,6 +10608,42 @@ func (m *TaskMutation) ResetOutputTokens() { delete(m.clearedFields, task.FieldOutputTokens) } +// SetIsSuggested sets the "is_suggested" field. +func (m *TaskMutation) SetIsSuggested(b bool) { + m.is_suggested = &b +} + +// IsSuggested returns the value of the "is_suggested" field in the mutation. +func (m *TaskMutation) IsSuggested() (r bool, exists bool) { + v := m.is_suggested + if v == nil { + return + } + return *v, true +} + +// OldIsSuggested returns the old "is_suggested" field's value of the Task entity. +// If the Task object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskMutation) OldIsSuggested(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIsSuggested is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIsSuggested requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIsSuggested: %w", err) + } + return oldValue.IsSuggested, nil +} + +// ResetIsSuggested resets all changes to the "is_suggested" field. +func (m *TaskMutation) ResetIsSuggested() { + m.is_suggested = nil +} + // SetCreatedAt sets the "created_at" field. func (m *TaskMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -10821,7 +10858,7 @@ func (m *TaskMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *TaskMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 15) if m.task_id != nil { fields = append(fields, task.FieldTaskID) } @@ -10858,6 +10895,9 @@ func (m *TaskMutation) Fields() []string { if m.output_tokens != nil { fields = append(fields, task.FieldOutputTokens) } + if m.is_suggested != nil { + fields = append(fields, task.FieldIsSuggested) + } if m.created_at != nil { fields = append(fields, task.FieldCreatedAt) } @@ -10896,6 +10936,8 @@ func (m *TaskMutation) Field(name string) (ent.Value, bool) { return m.InputTokens() case task.FieldOutputTokens: return m.OutputTokens() + case task.FieldIsSuggested: + return m.IsSuggested() case task.FieldCreatedAt: return m.CreatedAt() case task.FieldUpdatedAt: @@ -10933,6 +10975,8 @@ func (m *TaskMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldInputTokens(ctx) case task.FieldOutputTokens: return m.OldOutputTokens(ctx) + case task.FieldIsSuggested: + return m.OldIsSuggested(ctx) case task.FieldCreatedAt: return m.OldCreatedAt(ctx) case task.FieldUpdatedAt: @@ -11030,6 +11074,13 @@ func (m *TaskMutation) SetField(name string, value ent.Value) error { } m.SetOutputTokens(v) return nil + case task.FieldIsSuggested: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIsSuggested(v) + return nil case task.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -11225,6 +11276,9 @@ func (m *TaskMutation) ResetField(name string) error { case task.FieldOutputTokens: m.ResetOutputTokens() return nil + case task.FieldIsSuggested: + m.ResetIsSuggested() + return nil case task.FieldCreatedAt: m.ResetCreatedAt() return nil @@ -11366,6 +11420,9 @@ type TaskRecordMutation struct { completion *string output_tokens *int64 addoutput_tokens *int64 + code_lines *int64 + addcode_lines *int64 + code *string created_at *time.Time updated_at *time.Time clearedFields map[string]struct{} @@ -11706,6 +11763,111 @@ func (m *TaskRecordMutation) ResetOutputTokens() { m.addoutput_tokens = nil } +// SetCodeLines sets the "code_lines" field. +func (m *TaskRecordMutation) SetCodeLines(i int64) { + m.code_lines = &i + m.addcode_lines = nil +} + +// CodeLines returns the value of the "code_lines" field in the mutation. +func (m *TaskRecordMutation) CodeLines() (r int64, exists bool) { + v := m.code_lines + if v == nil { + return + } + return *v, true +} + +// OldCodeLines returns the old "code_lines" field's value of the TaskRecord entity. +// If the TaskRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskRecordMutation) OldCodeLines(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCodeLines is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCodeLines requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCodeLines: %w", err) + } + return oldValue.CodeLines, nil +} + +// AddCodeLines adds i to the "code_lines" field. +func (m *TaskRecordMutation) AddCodeLines(i int64) { + if m.addcode_lines != nil { + *m.addcode_lines += i + } else { + m.addcode_lines = &i + } +} + +// AddedCodeLines returns the value that was added to the "code_lines" field in this mutation. +func (m *TaskRecordMutation) AddedCodeLines() (r int64, exists bool) { + v := m.addcode_lines + if v == nil { + return + } + return *v, true +} + +// ResetCodeLines resets all changes to the "code_lines" field. +func (m *TaskRecordMutation) ResetCodeLines() { + m.code_lines = nil + m.addcode_lines = nil +} + +// SetCode sets the "code" field. +func (m *TaskRecordMutation) SetCode(s string) { + m.code = &s +} + +// Code returns the value of the "code" field in the mutation. +func (m *TaskRecordMutation) Code() (r string, exists bool) { + v := m.code + if v == nil { + return + } + return *v, true +} + +// OldCode returns the old "code" field's value of the TaskRecord entity. +// If the TaskRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *TaskRecordMutation) OldCode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCode: %w", err) + } + return oldValue.Code, nil +} + +// ClearCode clears the value of the "code" field. +func (m *TaskRecordMutation) ClearCode() { + m.code = nil + m.clearedFields[taskrecord.FieldCode] = struct{}{} +} + +// CodeCleared returns if the "code" field was cleared in this mutation. +func (m *TaskRecordMutation) CodeCleared() bool { + _, ok := m.clearedFields[taskrecord.FieldCode] + return ok +} + +// ResetCode resets all changes to the "code" field. +func (m *TaskRecordMutation) ResetCode() { + m.code = nil + delete(m.clearedFields, taskrecord.FieldCode) +} + // SetCreatedAt sets the "created_at" field. func (m *TaskRecordMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -11839,7 +12001,7 @@ func (m *TaskRecordMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *TaskRecordMutation) Fields() []string { - fields := make([]string, 0, 7) + fields := make([]string, 0, 9) if m.task != nil { fields = append(fields, taskrecord.FieldTaskID) } @@ -11855,6 +12017,12 @@ func (m *TaskRecordMutation) Fields() []string { if m.output_tokens != nil { fields = append(fields, taskrecord.FieldOutputTokens) } + if m.code_lines != nil { + fields = append(fields, taskrecord.FieldCodeLines) + } + if m.code != nil { + fields = append(fields, taskrecord.FieldCode) + } if m.created_at != nil { fields = append(fields, taskrecord.FieldCreatedAt) } @@ -11879,6 +12047,10 @@ func (m *TaskRecordMutation) Field(name string) (ent.Value, bool) { return m.Completion() case taskrecord.FieldOutputTokens: return m.OutputTokens() + case taskrecord.FieldCodeLines: + return m.CodeLines() + case taskrecord.FieldCode: + return m.Code() case taskrecord.FieldCreatedAt: return m.CreatedAt() case taskrecord.FieldUpdatedAt: @@ -11902,6 +12074,10 @@ func (m *TaskRecordMutation) OldField(ctx context.Context, name string) (ent.Val return m.OldCompletion(ctx) case taskrecord.FieldOutputTokens: return m.OldOutputTokens(ctx) + case taskrecord.FieldCodeLines: + return m.OldCodeLines(ctx) + case taskrecord.FieldCode: + return m.OldCode(ctx) case taskrecord.FieldCreatedAt: return m.OldCreatedAt(ctx) case taskrecord.FieldUpdatedAt: @@ -11950,6 +12126,20 @@ func (m *TaskRecordMutation) SetField(name string, value ent.Value) error { } m.SetOutputTokens(v) return nil + case taskrecord.FieldCodeLines: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCodeLines(v) + return nil + case taskrecord.FieldCode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCode(v) + return nil case taskrecord.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -11975,6 +12165,9 @@ func (m *TaskRecordMutation) AddedFields() []string { if m.addoutput_tokens != nil { fields = append(fields, taskrecord.FieldOutputTokens) } + if m.addcode_lines != nil { + fields = append(fields, taskrecord.FieldCodeLines) + } return fields } @@ -11985,6 +12178,8 @@ func (m *TaskRecordMutation) AddedField(name string) (ent.Value, bool) { switch name { case taskrecord.FieldOutputTokens: return m.AddedOutputTokens() + case taskrecord.FieldCodeLines: + return m.AddedCodeLines() } return nil, false } @@ -12001,6 +12196,13 @@ func (m *TaskRecordMutation) AddField(name string, value ent.Value) error { } m.AddOutputTokens(v) return nil + case taskrecord.FieldCodeLines: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCodeLines(v) + return nil } return fmt.Errorf("unknown TaskRecord numeric field %s", name) } @@ -12015,6 +12217,9 @@ func (m *TaskRecordMutation) ClearedFields() []string { if m.FieldCleared(taskrecord.FieldPrompt) { fields = append(fields, taskrecord.FieldPrompt) } + if m.FieldCleared(taskrecord.FieldCode) { + fields = append(fields, taskrecord.FieldCode) + } return fields } @@ -12035,6 +12240,9 @@ func (m *TaskRecordMutation) ClearField(name string) error { case taskrecord.FieldPrompt: m.ClearPrompt() return nil + case taskrecord.FieldCode: + m.ClearCode() + return nil } return fmt.Errorf("unknown TaskRecord nullable field %s", name) } @@ -12058,6 +12266,12 @@ func (m *TaskRecordMutation) ResetField(name string) error { case taskrecord.FieldOutputTokens: m.ResetOutputTokens() return nil + case taskrecord.FieldCodeLines: + m.ResetCodeLines() + return nil + case taskrecord.FieldCode: + m.ResetCode() + return nil case taskrecord.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/backend/db/runtime/runtime.go b/backend/db/runtime/runtime.go index 3d2d22a..8131d59 100644 --- a/backend/db/runtime/runtime.go +++ b/backend/db/runtime/runtime.go @@ -256,12 +256,16 @@ func init() { taskDescIsAccept := taskFields[6].Descriptor() // task.DefaultIsAccept holds the default value on creation for the is_accept field. task.DefaultIsAccept = taskDescIsAccept.Default.(bool) + // taskDescIsSuggested is the schema descriptor for is_suggested field. + taskDescIsSuggested := taskFields[13].Descriptor() + // task.DefaultIsSuggested holds the default value on creation for the is_suggested field. + task.DefaultIsSuggested = taskDescIsSuggested.Default.(bool) // taskDescCreatedAt is the schema descriptor for created_at field. - taskDescCreatedAt := taskFields[13].Descriptor() + taskDescCreatedAt := taskFields[14].Descriptor() // task.DefaultCreatedAt holds the default value on creation for the created_at field. task.DefaultCreatedAt = taskDescCreatedAt.Default.(func() time.Time) // taskDescUpdatedAt is the schema descriptor for updated_at field. - taskDescUpdatedAt := taskFields[14].Descriptor() + taskDescUpdatedAt := taskFields[15].Descriptor() // task.DefaultUpdatedAt holds the default value on creation for the updated_at field. task.DefaultUpdatedAt = taskDescUpdatedAt.Default.(func() time.Time) // task.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. @@ -269,11 +273,11 @@ func init() { taskrecordFields := schema.TaskRecord{}.Fields() _ = taskrecordFields // taskrecordDescCreatedAt is the schema descriptor for created_at field. - taskrecordDescCreatedAt := taskrecordFields[6].Descriptor() + taskrecordDescCreatedAt := taskrecordFields[8].Descriptor() // taskrecord.DefaultCreatedAt holds the default value on creation for the created_at field. taskrecord.DefaultCreatedAt = taskrecordDescCreatedAt.Default.(func() time.Time) // taskrecordDescUpdatedAt is the schema descriptor for updated_at field. - taskrecordDescUpdatedAt := taskrecordFields[7].Descriptor() + taskrecordDescUpdatedAt := taskrecordFields[9].Descriptor() // taskrecord.DefaultUpdatedAt holds the default value on creation for the updated_at field. taskrecord.DefaultUpdatedAt = taskrecordDescUpdatedAt.Default.(func() time.Time) // taskrecord.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. diff --git a/backend/db/task.go b/backend/db/task.go index 043c9c3..d9902b5 100644 --- a/backend/db/task.go +++ b/backend/db/task.go @@ -45,6 +45,8 @@ type Task struct { InputTokens int64 `json:"input_tokens,omitempty"` // OutputTokens holds the value of the "output_tokens" field. OutputTokens int64 `json:"output_tokens,omitempty"` + // IsSuggested holds the value of the "is_suggested" field. + IsSuggested bool `json:"is_suggested,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. @@ -104,7 +106,7 @@ func (*Task) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case task.FieldIsAccept: + case task.FieldIsAccept, task.FieldIsSuggested: values[i] = new(sql.NullBool) case task.FieldCodeLines, task.FieldInputTokens, task.FieldOutputTokens: values[i] = new(sql.NullInt64) @@ -207,6 +209,12 @@ func (t *Task) assignValues(columns []string, values []any) error { } else if value.Valid { t.OutputTokens = value.Int64 } + case task.FieldIsSuggested: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field is_suggested", values[i]) + } else if value.Valid { + t.IsSuggested = value.Bool + } case task.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -306,6 +314,9 @@ func (t *Task) String() string { builder.WriteString("output_tokens=") builder.WriteString(fmt.Sprintf("%v", t.OutputTokens)) builder.WriteString(", ") + builder.WriteString("is_suggested=") + builder.WriteString(fmt.Sprintf("%v", t.IsSuggested)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(t.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") diff --git a/backend/db/task/task.go b/backend/db/task/task.go index fed0b73..cba7158 100644 --- a/backend/db/task/task.go +++ b/backend/db/task/task.go @@ -38,6 +38,8 @@ const ( FieldInputTokens = "input_tokens" // FieldOutputTokens holds the string denoting the output_tokens field in the database. FieldOutputTokens = "output_tokens" + // FieldIsSuggested holds the string denoting the is_suggested field in the database. + FieldIsSuggested = "is_suggested" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. @@ -88,6 +90,7 @@ var Columns = []string{ FieldCodeLines, FieldInputTokens, FieldOutputTokens, + FieldIsSuggested, FieldCreatedAt, FieldUpdatedAt, } @@ -105,6 +108,8 @@ func ValidColumn(column string) bool { var ( // DefaultIsAccept holds the default value on creation for the "is_accept" field. DefaultIsAccept bool + // DefaultIsSuggested holds the default value on creation for the "is_suggested" field. + DefaultIsSuggested bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. @@ -181,6 +186,11 @@ func ByOutputTokens(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldOutputTokens, opts...).ToFunc() } +// ByIsSuggested orders the results by the is_suggested field. +func ByIsSuggested(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIsSuggested, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/db/task/where.go b/backend/db/task/where.go index 5321211..5f37f73 100644 --- a/backend/db/task/where.go +++ b/backend/db/task/where.go @@ -118,6 +118,11 @@ func OutputTokens(v int64) predicate.Task { return predicate.Task(sql.FieldEQ(FieldOutputTokens, v)) } +// IsSuggested applies equality check predicate on the "is_suggested" field. It's identical to IsSuggestedEQ. +func IsSuggested(v bool) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldIsSuggested, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.Task { return predicate.Task(sql.FieldEQ(FieldCreatedAt, v)) @@ -797,6 +802,16 @@ func OutputTokensNotNil() predicate.Task { return predicate.Task(sql.FieldNotNull(FieldOutputTokens)) } +// IsSuggestedEQ applies the EQ predicate on the "is_suggested" field. +func IsSuggestedEQ(v bool) predicate.Task { + return predicate.Task(sql.FieldEQ(FieldIsSuggested, v)) +} + +// IsSuggestedNEQ applies the NEQ predicate on the "is_suggested" field. +func IsSuggestedNEQ(v bool) predicate.Task { + return predicate.Task(sql.FieldNEQ(FieldIsSuggested, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Task { return predicate.Task(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/db/task_create.go b/backend/db/task_create.go index adcfb3d..50d4335 100644 --- a/backend/db/task_create.go +++ b/backend/db/task_create.go @@ -180,6 +180,20 @@ func (tc *TaskCreate) SetNillableOutputTokens(i *int64) *TaskCreate { return tc } +// SetIsSuggested sets the "is_suggested" field. +func (tc *TaskCreate) SetIsSuggested(b bool) *TaskCreate { + tc.mutation.SetIsSuggested(b) + return tc +} + +// SetNillableIsSuggested sets the "is_suggested" field if the given value is not nil. +func (tc *TaskCreate) SetNillableIsSuggested(b *bool) *TaskCreate { + if b != nil { + tc.SetIsSuggested(*b) + } + return tc +} + // SetCreatedAt sets the "created_at" field. func (tc *TaskCreate) SetCreatedAt(t time.Time) *TaskCreate { tc.mutation.SetCreatedAt(t) @@ -278,6 +292,10 @@ func (tc *TaskCreate) defaults() { v := task.DefaultIsAccept tc.mutation.SetIsAccept(v) } + if _, ok := tc.mutation.IsSuggested(); !ok { + v := task.DefaultIsSuggested + tc.mutation.SetIsSuggested(v) + } if _, ok := tc.mutation.CreatedAt(); !ok { v := task.DefaultCreatedAt() tc.mutation.SetCreatedAt(v) @@ -299,6 +317,9 @@ func (tc *TaskCreate) check() error { if _, ok := tc.mutation.IsAccept(); !ok { return &ValidationError{Name: "is_accept", err: errors.New(`db: missing required field "Task.is_accept"`)} } + if _, ok := tc.mutation.IsSuggested(); !ok { + return &ValidationError{Name: "is_suggested", err: errors.New(`db: missing required field "Task.is_suggested"`)} + } if _, ok := tc.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "Task.created_at"`)} } @@ -381,6 +402,10 @@ func (tc *TaskCreate) createSpec() (*Task, *sqlgraph.CreateSpec) { _spec.SetField(task.FieldOutputTokens, field.TypeInt64, value) _node.OutputTokens = value } + if value, ok := tc.mutation.IsSuggested(); ok { + _spec.SetField(task.FieldIsSuggested, field.TypeBool, value) + _node.IsSuggested = value + } if value, ok := tc.mutation.CreatedAt(); ok { _spec.SetField(task.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -707,6 +732,18 @@ func (u *TaskUpsert) ClearOutputTokens() *TaskUpsert { return u } +// SetIsSuggested sets the "is_suggested" field. +func (u *TaskUpsert) SetIsSuggested(v bool) *TaskUpsert { + u.Set(task.FieldIsSuggested, v) + return u +} + +// UpdateIsSuggested sets the "is_suggested" field to the value that was provided on create. +func (u *TaskUpsert) UpdateIsSuggested() *TaskUpsert { + u.SetExcluded(task.FieldIsSuggested) + return u +} + // SetCreatedAt sets the "created_at" field. func (u *TaskUpsert) SetCreatedAt(v time.Time) *TaskUpsert { u.Set(task.FieldCreatedAt, v) @@ -1031,6 +1068,20 @@ func (u *TaskUpsertOne) ClearOutputTokens() *TaskUpsertOne { }) } +// SetIsSuggested sets the "is_suggested" field. +func (u *TaskUpsertOne) SetIsSuggested(v bool) *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.SetIsSuggested(v) + }) +} + +// UpdateIsSuggested sets the "is_suggested" field to the value that was provided on create. +func (u *TaskUpsertOne) UpdateIsSuggested() *TaskUpsertOne { + return u.Update(func(s *TaskUpsert) { + s.UpdateIsSuggested() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *TaskUpsertOne) SetCreatedAt(v time.Time) *TaskUpsertOne { return u.Update(func(s *TaskUpsert) { @@ -1526,6 +1577,20 @@ func (u *TaskUpsertBulk) ClearOutputTokens() *TaskUpsertBulk { }) } +// SetIsSuggested sets the "is_suggested" field. +func (u *TaskUpsertBulk) SetIsSuggested(v bool) *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.SetIsSuggested(v) + }) +} + +// UpdateIsSuggested sets the "is_suggested" field to the value that was provided on create. +func (u *TaskUpsertBulk) UpdateIsSuggested() *TaskUpsertBulk { + return u.Update(func(s *TaskUpsert) { + s.UpdateIsSuggested() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *TaskUpsertBulk) SetCreatedAt(v time.Time) *TaskUpsertBulk { return u.Update(func(s *TaskUpsert) { diff --git a/backend/db/task_update.go b/backend/db/task_update.go index dfddaa9..d202867 100644 --- a/backend/db/task_update.go +++ b/backend/db/task_update.go @@ -277,6 +277,20 @@ func (tu *TaskUpdate) ClearOutputTokens() *TaskUpdate { return tu } +// SetIsSuggested sets the "is_suggested" field. +func (tu *TaskUpdate) SetIsSuggested(b bool) *TaskUpdate { + tu.mutation.SetIsSuggested(b) + return tu +} + +// SetNillableIsSuggested sets the "is_suggested" field if the given value is not nil. +func (tu *TaskUpdate) SetNillableIsSuggested(b *bool) *TaskUpdate { + if b != nil { + tu.SetIsSuggested(*b) + } + return tu +} + // SetCreatedAt sets the "created_at" field. func (tu *TaskUpdate) SetCreatedAt(t time.Time) *TaskUpdate { tu.mutation.SetCreatedAt(t) @@ -471,6 +485,9 @@ func (tu *TaskUpdate) sqlSave(ctx context.Context) (n int, err error) { if tu.mutation.OutputTokensCleared() { _spec.ClearField(task.FieldOutputTokens, field.TypeInt64) } + if value, ok := tu.mutation.IsSuggested(); ok { + _spec.SetField(task.FieldIsSuggested, field.TypeBool, value) + } if value, ok := tu.mutation.CreatedAt(); ok { _spec.SetField(task.FieldCreatedAt, field.TypeTime, value) } @@ -845,6 +862,20 @@ func (tuo *TaskUpdateOne) ClearOutputTokens() *TaskUpdateOne { return tuo } +// SetIsSuggested sets the "is_suggested" field. +func (tuo *TaskUpdateOne) SetIsSuggested(b bool) *TaskUpdateOne { + tuo.mutation.SetIsSuggested(b) + return tuo +} + +// SetNillableIsSuggested sets the "is_suggested" field if the given value is not nil. +func (tuo *TaskUpdateOne) SetNillableIsSuggested(b *bool) *TaskUpdateOne { + if b != nil { + tuo.SetIsSuggested(*b) + } + return tuo +} + // SetCreatedAt sets the "created_at" field. func (tuo *TaskUpdateOne) SetCreatedAt(t time.Time) *TaskUpdateOne { tuo.mutation.SetCreatedAt(t) @@ -1069,6 +1100,9 @@ func (tuo *TaskUpdateOne) sqlSave(ctx context.Context) (_node *Task, err error) if tuo.mutation.OutputTokensCleared() { _spec.ClearField(task.FieldOutputTokens, field.TypeInt64) } + if value, ok := tuo.mutation.IsSuggested(); ok { + _spec.SetField(task.FieldIsSuggested, field.TypeBool, value) + } if value, ok := tuo.mutation.CreatedAt(); ok { _spec.SetField(task.FieldCreatedAt, field.TypeTime, value) } diff --git a/backend/db/taskrecord.go b/backend/db/taskrecord.go index 88efe18..6a270bd 100644 --- a/backend/db/taskrecord.go +++ b/backend/db/taskrecord.go @@ -30,6 +30,10 @@ type TaskRecord struct { Completion string `json:"completion,omitempty"` // OutputTokens holds the value of the "output_tokens" field. OutputTokens int64 `json:"output_tokens,omitempty"` + // CodeLines holds the value of the "code_lines" field. + CodeLines int64 `json:"code_lines,omitempty"` + // Code holds the value of the "code" field. + Code string `json:"code,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. @@ -65,9 +69,9 @@ func (*TaskRecord) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case taskrecord.FieldOutputTokens: + case taskrecord.FieldOutputTokens, taskrecord.FieldCodeLines: values[i] = new(sql.NullInt64) - case taskrecord.FieldPrompt, taskrecord.FieldRole, taskrecord.FieldCompletion: + case taskrecord.FieldPrompt, taskrecord.FieldRole, taskrecord.FieldCompletion, taskrecord.FieldCode: values[i] = new(sql.NullString) case taskrecord.FieldCreatedAt, taskrecord.FieldUpdatedAt: values[i] = new(sql.NullTime) @@ -124,6 +128,18 @@ func (tr *TaskRecord) assignValues(columns []string, values []any) error { } else if value.Valid { tr.OutputTokens = value.Int64 } + case taskrecord.FieldCodeLines: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field code_lines", values[i]) + } else if value.Valid { + tr.CodeLines = value.Int64 + } + case taskrecord.FieldCode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field code", values[i]) + } else if value.Valid { + tr.Code = value.String + } case taskrecord.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -192,6 +208,12 @@ func (tr *TaskRecord) String() string { builder.WriteString("output_tokens=") builder.WriteString(fmt.Sprintf("%v", tr.OutputTokens)) builder.WriteString(", ") + builder.WriteString("code_lines=") + builder.WriteString(fmt.Sprintf("%v", tr.CodeLines)) + builder.WriteString(", ") + builder.WriteString("code=") + builder.WriteString(tr.Code) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(tr.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") diff --git a/backend/db/taskrecord/taskrecord.go b/backend/db/taskrecord/taskrecord.go index bca05dd..e290321 100644 --- a/backend/db/taskrecord/taskrecord.go +++ b/backend/db/taskrecord/taskrecord.go @@ -24,6 +24,10 @@ const ( FieldCompletion = "completion" // FieldOutputTokens holds the string denoting the output_tokens field in the database. FieldOutputTokens = "output_tokens" + // FieldCodeLines holds the string denoting the code_lines field in the database. + FieldCodeLines = "code_lines" + // FieldCode holds the string denoting the code field in the database. + FieldCode = "code" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. @@ -49,6 +53,8 @@ var Columns = []string{ FieldRole, FieldCompletion, FieldOutputTokens, + FieldCodeLines, + FieldCode, FieldCreatedAt, FieldUpdatedAt, } @@ -105,6 +111,16 @@ func ByOutputTokens(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldOutputTokens, opts...).ToFunc() } +// ByCodeLines orders the results by the code_lines field. +func ByCodeLines(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCodeLines, opts...).ToFunc() +} + +// ByCode orders the results by the code field. +func ByCode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCode, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/db/taskrecord/where.go b/backend/db/taskrecord/where.go index 38a4a74..a8736f8 100644 --- a/backend/db/taskrecord/where.go +++ b/backend/db/taskrecord/where.go @@ -83,6 +83,16 @@ func OutputTokens(v int64) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldEQ(FieldOutputTokens, v)) } +// CodeLines applies equality check predicate on the "code_lines" field. It's identical to CodeLinesEQ. +func CodeLines(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEQ(FieldCodeLines, v)) +} + +// Code applies equality check predicate on the "code" field. It's identical to CodeEQ. +func Code(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEQ(FieldCode, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldEQ(FieldCreatedAt, v)) @@ -387,6 +397,121 @@ func OutputTokensLTE(v int64) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldLTE(FieldOutputTokens, v)) } +// CodeLinesEQ applies the EQ predicate on the "code_lines" field. +func CodeLinesEQ(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEQ(FieldCodeLines, v)) +} + +// CodeLinesNEQ applies the NEQ predicate on the "code_lines" field. +func CodeLinesNEQ(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNEQ(FieldCodeLines, v)) +} + +// CodeLinesIn applies the In predicate on the "code_lines" field. +func CodeLinesIn(vs ...int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldIn(FieldCodeLines, vs...)) +} + +// CodeLinesNotIn applies the NotIn predicate on the "code_lines" field. +func CodeLinesNotIn(vs ...int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNotIn(FieldCodeLines, vs...)) +} + +// CodeLinesGT applies the GT predicate on the "code_lines" field. +func CodeLinesGT(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldGT(FieldCodeLines, v)) +} + +// CodeLinesGTE applies the GTE predicate on the "code_lines" field. +func CodeLinesGTE(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldGTE(FieldCodeLines, v)) +} + +// CodeLinesLT applies the LT predicate on the "code_lines" field. +func CodeLinesLT(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldLT(FieldCodeLines, v)) +} + +// CodeLinesLTE applies the LTE predicate on the "code_lines" field. +func CodeLinesLTE(v int64) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldLTE(FieldCodeLines, v)) +} + +// CodeEQ applies the EQ predicate on the "code" field. +func CodeEQ(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEQ(FieldCode, v)) +} + +// CodeNEQ applies the NEQ predicate on the "code" field. +func CodeNEQ(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNEQ(FieldCode, v)) +} + +// CodeIn applies the In predicate on the "code" field. +func CodeIn(vs ...string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldIn(FieldCode, vs...)) +} + +// CodeNotIn applies the NotIn predicate on the "code" field. +func CodeNotIn(vs ...string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNotIn(FieldCode, vs...)) +} + +// CodeGT applies the GT predicate on the "code" field. +func CodeGT(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldGT(FieldCode, v)) +} + +// CodeGTE applies the GTE predicate on the "code" field. +func CodeGTE(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldGTE(FieldCode, v)) +} + +// CodeLT applies the LT predicate on the "code" field. +func CodeLT(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldLT(FieldCode, v)) +} + +// CodeLTE applies the LTE predicate on the "code" field. +func CodeLTE(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldLTE(FieldCode, v)) +} + +// CodeContains applies the Contains predicate on the "code" field. +func CodeContains(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldContains(FieldCode, v)) +} + +// CodeHasPrefix applies the HasPrefix predicate on the "code" field. +func CodeHasPrefix(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldHasPrefix(FieldCode, v)) +} + +// CodeHasSuffix applies the HasSuffix predicate on the "code" field. +func CodeHasSuffix(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldHasSuffix(FieldCode, v)) +} + +// CodeIsNil applies the IsNil predicate on the "code" field. +func CodeIsNil() predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldIsNull(FieldCode)) +} + +// CodeNotNil applies the NotNil predicate on the "code" field. +func CodeNotNil() predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldNotNull(FieldCode)) +} + +// CodeEqualFold applies the EqualFold predicate on the "code" field. +func CodeEqualFold(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldEqualFold(FieldCode, v)) +} + +// CodeContainsFold applies the ContainsFold predicate on the "code" field. +func CodeContainsFold(v string) predicate.TaskRecord { + return predicate.TaskRecord(sql.FieldContainsFold(FieldCode, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.TaskRecord { return predicate.TaskRecord(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/db/taskrecord_create.go b/backend/db/taskrecord_create.go index aa78db4..e7b89bc 100644 --- a/backend/db/taskrecord_create.go +++ b/backend/db/taskrecord_create.go @@ -72,6 +72,26 @@ func (trc *TaskRecordCreate) SetOutputTokens(i int64) *TaskRecordCreate { return trc } +// SetCodeLines sets the "code_lines" field. +func (trc *TaskRecordCreate) SetCodeLines(i int64) *TaskRecordCreate { + trc.mutation.SetCodeLines(i) + return trc +} + +// SetCode sets the "code" field. +func (trc *TaskRecordCreate) SetCode(s string) *TaskRecordCreate { + trc.mutation.SetCode(s) + return trc +} + +// SetNillableCode sets the "code" field if the given value is not nil. +func (trc *TaskRecordCreate) SetNillableCode(s *string) *TaskRecordCreate { + if s != nil { + trc.SetCode(*s) + } + return trc +} + // SetCreatedAt sets the "created_at" field. func (trc *TaskRecordCreate) SetCreatedAt(t time.Time) *TaskRecordCreate { trc.mutation.SetCreatedAt(t) @@ -167,6 +187,9 @@ func (trc *TaskRecordCreate) check() error { if _, ok := trc.mutation.OutputTokens(); !ok { return &ValidationError{Name: "output_tokens", err: errors.New(`db: missing required field "TaskRecord.output_tokens"`)} } + if _, ok := trc.mutation.CodeLines(); !ok { + return &ValidationError{Name: "code_lines", err: errors.New(`db: missing required field "TaskRecord.code_lines"`)} + } if _, ok := trc.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "TaskRecord.created_at"`)} } @@ -225,6 +248,14 @@ func (trc *TaskRecordCreate) createSpec() (*TaskRecord, *sqlgraph.CreateSpec) { _spec.SetField(taskrecord.FieldOutputTokens, field.TypeInt64, value) _node.OutputTokens = value } + if value, ok := trc.mutation.CodeLines(); ok { + _spec.SetField(taskrecord.FieldCodeLines, field.TypeInt64, value) + _node.CodeLines = value + } + if value, ok := trc.mutation.Code(); ok { + _spec.SetField(taskrecord.FieldCode, field.TypeString, value) + _node.Code = value + } if value, ok := trc.mutation.CreatedAt(); ok { _spec.SetField(taskrecord.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -380,6 +411,42 @@ func (u *TaskRecordUpsert) AddOutputTokens(v int64) *TaskRecordUpsert { return u } +// SetCodeLines sets the "code_lines" field. +func (u *TaskRecordUpsert) SetCodeLines(v int64) *TaskRecordUpsert { + u.Set(taskrecord.FieldCodeLines, v) + return u +} + +// UpdateCodeLines sets the "code_lines" field to the value that was provided on create. +func (u *TaskRecordUpsert) UpdateCodeLines() *TaskRecordUpsert { + u.SetExcluded(taskrecord.FieldCodeLines) + return u +} + +// AddCodeLines adds v to the "code_lines" field. +func (u *TaskRecordUpsert) AddCodeLines(v int64) *TaskRecordUpsert { + u.Add(taskrecord.FieldCodeLines, v) + return u +} + +// SetCode sets the "code" field. +func (u *TaskRecordUpsert) SetCode(v string) *TaskRecordUpsert { + u.Set(taskrecord.FieldCode, v) + return u +} + +// UpdateCode sets the "code" field to the value that was provided on create. +func (u *TaskRecordUpsert) UpdateCode() *TaskRecordUpsert { + u.SetExcluded(taskrecord.FieldCode) + return u +} + +// ClearCode clears the value of the "code" field. +func (u *TaskRecordUpsert) ClearCode() *TaskRecordUpsert { + u.SetNull(taskrecord.FieldCode) + return u +} + // SetCreatedAt sets the "created_at" field. func (u *TaskRecordUpsert) SetCreatedAt(v time.Time) *TaskRecordUpsert { u.Set(taskrecord.FieldCreatedAt, v) @@ -543,6 +610,48 @@ func (u *TaskRecordUpsertOne) UpdateOutputTokens() *TaskRecordUpsertOne { }) } +// SetCodeLines sets the "code_lines" field. +func (u *TaskRecordUpsertOne) SetCodeLines(v int64) *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.SetCodeLines(v) + }) +} + +// AddCodeLines adds v to the "code_lines" field. +func (u *TaskRecordUpsertOne) AddCodeLines(v int64) *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.AddCodeLines(v) + }) +} + +// UpdateCodeLines sets the "code_lines" field to the value that was provided on create. +func (u *TaskRecordUpsertOne) UpdateCodeLines() *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdateCodeLines() + }) +} + +// SetCode sets the "code" field. +func (u *TaskRecordUpsertOne) SetCode(v string) *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.SetCode(v) + }) +} + +// UpdateCode sets the "code" field to the value that was provided on create. +func (u *TaskRecordUpsertOne) UpdateCode() *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdateCode() + }) +} + +// ClearCode clears the value of the "code" field. +func (u *TaskRecordUpsertOne) ClearCode() *TaskRecordUpsertOne { + return u.Update(func(s *TaskRecordUpsert) { + s.ClearCode() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *TaskRecordUpsertOne) SetCreatedAt(v time.Time) *TaskRecordUpsertOne { return u.Update(func(s *TaskRecordUpsert) { @@ -877,6 +986,48 @@ func (u *TaskRecordUpsertBulk) UpdateOutputTokens() *TaskRecordUpsertBulk { }) } +// SetCodeLines sets the "code_lines" field. +func (u *TaskRecordUpsertBulk) SetCodeLines(v int64) *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.SetCodeLines(v) + }) +} + +// AddCodeLines adds v to the "code_lines" field. +func (u *TaskRecordUpsertBulk) AddCodeLines(v int64) *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.AddCodeLines(v) + }) +} + +// UpdateCodeLines sets the "code_lines" field to the value that was provided on create. +func (u *TaskRecordUpsertBulk) UpdateCodeLines() *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdateCodeLines() + }) +} + +// SetCode sets the "code" field. +func (u *TaskRecordUpsertBulk) SetCode(v string) *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.SetCode(v) + }) +} + +// UpdateCode sets the "code" field to the value that was provided on create. +func (u *TaskRecordUpsertBulk) UpdateCode() *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.UpdateCode() + }) +} + +// ClearCode clears the value of the "code" field. +func (u *TaskRecordUpsertBulk) ClearCode() *TaskRecordUpsertBulk { + return u.Update(func(s *TaskRecordUpsert) { + s.ClearCode() + }) +} + // SetCreatedAt sets the "created_at" field. func (u *TaskRecordUpsertBulk) SetCreatedAt(v time.Time) *TaskRecordUpsertBulk { return u.Update(func(s *TaskRecordUpsert) { diff --git a/backend/db/taskrecord_update.go b/backend/db/taskrecord_update.go index 4f16d34..1c5be23 100644 --- a/backend/db/taskrecord_update.go +++ b/backend/db/taskrecord_update.go @@ -121,6 +121,47 @@ func (tru *TaskRecordUpdate) AddOutputTokens(i int64) *TaskRecordUpdate { return tru } +// SetCodeLines sets the "code_lines" field. +func (tru *TaskRecordUpdate) SetCodeLines(i int64) *TaskRecordUpdate { + tru.mutation.ResetCodeLines() + tru.mutation.SetCodeLines(i) + return tru +} + +// SetNillableCodeLines sets the "code_lines" field if the given value is not nil. +func (tru *TaskRecordUpdate) SetNillableCodeLines(i *int64) *TaskRecordUpdate { + if i != nil { + tru.SetCodeLines(*i) + } + return tru +} + +// AddCodeLines adds i to the "code_lines" field. +func (tru *TaskRecordUpdate) AddCodeLines(i int64) *TaskRecordUpdate { + tru.mutation.AddCodeLines(i) + return tru +} + +// SetCode sets the "code" field. +func (tru *TaskRecordUpdate) SetCode(s string) *TaskRecordUpdate { + tru.mutation.SetCode(s) + return tru +} + +// SetNillableCode sets the "code" field if the given value is not nil. +func (tru *TaskRecordUpdate) SetNillableCode(s *string) *TaskRecordUpdate { + if s != nil { + tru.SetCode(*s) + } + return tru +} + +// ClearCode clears the value of the "code" field. +func (tru *TaskRecordUpdate) ClearCode() *TaskRecordUpdate { + tru.mutation.ClearCode() + return tru +} + // SetCreatedAt sets the "created_at" field. func (tru *TaskRecordUpdate) SetCreatedAt(t time.Time) *TaskRecordUpdate { tru.mutation.SetCreatedAt(t) @@ -226,6 +267,18 @@ func (tru *TaskRecordUpdate) sqlSave(ctx context.Context) (n int, err error) { if value, ok := tru.mutation.AddedOutputTokens(); ok { _spec.AddField(taskrecord.FieldOutputTokens, field.TypeInt64, value) } + if value, ok := tru.mutation.CodeLines(); ok { + _spec.SetField(taskrecord.FieldCodeLines, field.TypeInt64, value) + } + if value, ok := tru.mutation.AddedCodeLines(); ok { + _spec.AddField(taskrecord.FieldCodeLines, field.TypeInt64, value) + } + if value, ok := tru.mutation.Code(); ok { + _spec.SetField(taskrecord.FieldCode, field.TypeString, value) + } + if tru.mutation.CodeCleared() { + _spec.ClearField(taskrecord.FieldCode, field.TypeString) + } if value, ok := tru.mutation.CreatedAt(); ok { _spec.SetField(taskrecord.FieldCreatedAt, field.TypeTime, value) } @@ -372,6 +425,47 @@ func (truo *TaskRecordUpdateOne) AddOutputTokens(i int64) *TaskRecordUpdateOne { return truo } +// SetCodeLines sets the "code_lines" field. +func (truo *TaskRecordUpdateOne) SetCodeLines(i int64) *TaskRecordUpdateOne { + truo.mutation.ResetCodeLines() + truo.mutation.SetCodeLines(i) + return truo +} + +// SetNillableCodeLines sets the "code_lines" field if the given value is not nil. +func (truo *TaskRecordUpdateOne) SetNillableCodeLines(i *int64) *TaskRecordUpdateOne { + if i != nil { + truo.SetCodeLines(*i) + } + return truo +} + +// AddCodeLines adds i to the "code_lines" field. +func (truo *TaskRecordUpdateOne) AddCodeLines(i int64) *TaskRecordUpdateOne { + truo.mutation.AddCodeLines(i) + return truo +} + +// SetCode sets the "code" field. +func (truo *TaskRecordUpdateOne) SetCode(s string) *TaskRecordUpdateOne { + truo.mutation.SetCode(s) + return truo +} + +// SetNillableCode sets the "code" field if the given value is not nil. +func (truo *TaskRecordUpdateOne) SetNillableCode(s *string) *TaskRecordUpdateOne { + if s != nil { + truo.SetCode(*s) + } + return truo +} + +// ClearCode clears the value of the "code" field. +func (truo *TaskRecordUpdateOne) ClearCode() *TaskRecordUpdateOne { + truo.mutation.ClearCode() + return truo +} + // SetCreatedAt sets the "created_at" field. func (truo *TaskRecordUpdateOne) SetCreatedAt(t time.Time) *TaskRecordUpdateOne { truo.mutation.SetCreatedAt(t) @@ -507,6 +601,18 @@ func (truo *TaskRecordUpdateOne) sqlSave(ctx context.Context) (_node *TaskRecord if value, ok := truo.mutation.AddedOutputTokens(); ok { _spec.AddField(taskrecord.FieldOutputTokens, field.TypeInt64, value) } + if value, ok := truo.mutation.CodeLines(); ok { + _spec.SetField(taskrecord.FieldCodeLines, field.TypeInt64, value) + } + if value, ok := truo.mutation.AddedCodeLines(); ok { + _spec.AddField(taskrecord.FieldCodeLines, field.TypeInt64, value) + } + if value, ok := truo.mutation.Code(); ok { + _spec.SetField(taskrecord.FieldCode, field.TypeString, value) + } + if truo.mutation.CodeCleared() { + _spec.ClearField(taskrecord.FieldCode, field.TypeString) + } if value, ok := truo.mutation.CreatedAt(); ok { _spec.SetField(taskrecord.FieldCreatedAt, field.TypeTime, value) } diff --git a/backend/docs/swagger.json b/backend/docs/swagger.json index 95e1999..eac0ad2 100644 --- a/backend/docs/swagger.json +++ b/backend/docs/swagger.json @@ -434,6 +434,12 @@ "description": "每页多少条记录", "name": "size", "in": "query" + }, + { + "type": "string", + "description": "工作模式", + "name": "work_mode", + "in": "query" } ], "responses": { @@ -553,6 +559,12 @@ "description": "每页多少条记录", "name": "size", "in": "query" + }, + { + "type": "string", + "description": "工作模式", + "name": "work_mode", + "in": "query" } ], "responses": { @@ -2033,6 +2045,41 @@ } } } + }, + "/v1/report": { + "post": { + "description": "报告", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "OpenAIV1" + ], + "summary": "报告", + "operationId": "report", + "parameters": [ + { + "description": "报告请求", + "name": "param", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/domain.ReportReq" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/web.Resp" + } + } + } + } } }, "definitions": { @@ -2051,11 +2098,13 @@ "type": "string", "enum": [ "user", - "assistant" + "assistant", + "system" ], "x-enum-varnames": [ "ChatRoleUser", - "ChatRoleAssistant" + "ChatRoleAssistant", + "ChatRoleSystem" ] }, "consts.ModelProvider": { @@ -2113,6 +2162,19 @@ "ModelTypeReranker" ] }, + "consts.ReportAction": { + "type": "string", + "enum": [ + "accept", + "suggest", + "file_written" + ], + "x-enum-varnames": [ + "ReportActionAccept", + "ReportActionSuggest", + "ReportActionFileWritten" + ] + }, "consts.UserPlatform": { "type": "string", "enum": [ @@ -3113,6 +3175,26 @@ } } }, + "domain.ReportReq": { + "type": "object", + "properties": { + "action": { + "$ref": "#/definitions/consts.ReportAction" + }, + "content": { + "description": "内容", + "type": "string" + }, + "id": { + "description": "task_id or resp_id", + "type": "string" + }, + "tool": { + "description": "工具", + "type": "string" + } + } + }, "domain.Setting": { "type": "object", "properties": { @@ -3474,6 +3556,14 @@ "description": "代码行数", "type": "integer" }, + "user": { + "description": "用户信息", + "allOf": [ + { + "$ref": "#/definitions/domain.User" + } + ] + }, "username": { "description": "用户名", "type": "string" diff --git a/backend/domain/billing.go b/backend/domain/billing.go index 81de700..a3346df 100644 --- a/backend/domain/billing.go +++ b/backend/domain/billing.go @@ -28,6 +28,7 @@ type ListRecordReq struct { *web.Pagination Author string `json:"author" query:"author"` // 作者 Language string `json:"language" query:"language"` // 语言 + WorkMode string `json:"work_mode" query:"work_mode"` // 工作模式 IsAccept *bool `json:"is_accept" query:"is_accept"` // 是否接受筛选 } diff --git a/backend/domain/dashboard.go b/backend/domain/dashboard.go index 0a19360..99e1358 100644 --- a/backend/domain/dashboard.go +++ b/backend/domain/dashboard.go @@ -5,6 +5,7 @@ import ( "time" "github.com/chaitin/MonkeyCode/backend/db" + "github.com/chaitin/MonkeyCode/backend/pkg/cvt" ) type DashboardUsecase interface { @@ -62,6 +63,7 @@ type UserHeatmap struct { type UserCodeRank struct { Username string `json:"username"` // 用户名 Lines int64 `json:"lines"` // 代码行数 + User *User `json:"user"` // 用户信息 } func (u *UserCodeRank) From(d *db.Task) *UserCodeRank { @@ -70,6 +72,7 @@ func (u *UserCodeRank) From(d *db.Task) *UserCodeRank { } u.Username = d.Edges.User.Username u.Lines = d.CodeLines + u.User = cvt.From(d.Edges.User, &User{}) return u } diff --git a/backend/domain/proxy.go b/backend/domain/proxy.go index d344a98..24a5eb6 100644 --- a/backend/domain/proxy.go +++ b/backend/domain/proxy.go @@ -22,12 +22,14 @@ type ProxyUsecase interface { Record(ctx context.Context, record *RecordParam) error ValidateApiKey(ctx context.Context, key string) (*ApiKey, error) AcceptCompletion(ctx context.Context, req *AcceptCompletionReq) error + Report(ctx context.Context, req *ReportReq) error } type ProxyRepo interface { Record(ctx context.Context, record *RecordParam) error UpdateByTaskID(ctx context.Context, taskID string, fn func(*db.TaskUpdateOne)) error AcceptCompletion(ctx context.Context, req *AcceptCompletionReq) error + Report(ctx context.Context, req *ReportReq) error SelectModelWithLoadBalancing(modelName string, modelType consts.ModelType) (*db.Model, error) ValidateApiKey(ctx context.Context, key string) (*db.ApiKey, error) } @@ -42,6 +44,13 @@ type AcceptCompletionReq struct { Completion string `json:"completion"` // 补全内容 } +type ReportReq struct { + Action consts.ReportAction `json:"action"` + ID string `json:"id"` // task_id or resp_id + Content string `json:"content"` // 内容 + Tool string `json:"tool"` // 工具 +} + type RecordParam struct { RequestID string TaskID string @@ -57,6 +66,7 @@ type RecordParam struct { Completion string WorkMode string CodeLines int64 + Code string } func (r *RecordParam) Clone() *RecordParam { diff --git a/backend/ent/schema/task.go b/backend/ent/schema/task.go index 265385a..ae58d8f 100644 --- a/backend/ent/schema/task.go +++ b/backend/ent/schema/task.go @@ -41,6 +41,7 @@ func (Task) Fields() []ent.Field { field.Int64("code_lines").Optional(), field.Int64("input_tokens").Optional(), field.Int64("output_tokens").Optional(), + field.Bool("is_suggested").Default(false), field.Time("created_at").Default(time.Now), field.Time("updated_at").Default(time.Now).UpdateDefault(time.Now), } diff --git a/backend/ent/schema/taskrecord.go b/backend/ent/schema/taskrecord.go index edb1f3e..206fac6 100644 --- a/backend/ent/schema/taskrecord.go +++ b/backend/ent/schema/taskrecord.go @@ -35,6 +35,8 @@ func (TaskRecord) Fields() []ent.Field { field.String("role").GoType(consts.ChatRole("")), field.String("completion"), field.Int64("output_tokens"), + field.Int64("code_lines"), + field.String("code").Optional(), field.Time("created_at").Default(time.Now), field.Time("updated_at").Default(time.Now).UpdateDefault(time.Now), } diff --git a/backend/go.mod b/backend/go.mod index e11226b..118ffea 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -20,6 +20,7 @@ require ( golang.org/x/crypto v0.39.0 golang.org/x/oauth2 v0.18.0 golang.org/x/text v0.26.0 + golang.org/x/time v0.11.0 ) require ( @@ -97,7 +98,6 @@ require ( golang.org/x/net v0.41.0 // indirect golang.org/x/sync v0.15.0 // indirect golang.org/x/sys v0.33.0 // indirect - golang.org/x/time v0.11.0 // indirect golang.org/x/tools v0.33.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/protobuf v1.34.2 // indirect diff --git a/backend/internal/billing/repo/billing.go b/backend/internal/billing/repo/billing.go index 058786d..0ea936f 100644 --- a/backend/internal/billing/repo/billing.go +++ b/backend/internal/billing/repo/billing.go @@ -28,6 +28,7 @@ func (b *BillingRepo) ChatInfo(ctx context.Context, id string) (*domain.ChatInfo record, err := b.db.Task.Query(). WithTaskRecords(func(trq *db.TaskRecordQuery) { trq.Order(taskrecord.ByCreatedAt(sql.OrderAsc())) + trq.Where(taskrecord.RoleNEQ(consts.ChatRoleSystem)) }). Where(task.TaskID(id)). First(ctx) @@ -98,6 +99,10 @@ func filterTask(q *db.TaskQuery, req domain.ListRecordReq) { ) }) } + + if req.WorkMode != "" { + q.Where(task.WorkMode(req.WorkMode)) + } } // ListCompletionRecord implements domain.BillingRepo. @@ -110,6 +115,7 @@ func (b *BillingRepo) ListCompletionRecord(ctx context.Context, req domain.ListR trq.Order(taskrecord.ByCreatedAt(sql.OrderAsc())) }). Where(task.ModelType(consts.ModelTypeCoder)). + Where(task.IsSuggested(true)). Order(task.ByCreatedAt(sql.OrderDesc())) filterTask(q, req) diff --git a/backend/internal/dashboard/repo/dashboard.go b/backend/internal/dashboard/repo/dashboard.go index 823fcef..6aaf5b1 100644 --- a/backend/internal/dashboard/repo/dashboard.go +++ b/backend/internal/dashboard/repo/dashboard.go @@ -224,6 +224,7 @@ func (d *DashboardRepo) UserCodeRank(ctx context.Context, req domain.StatisticsF return &domain.UserCodeRank{ Username: m[v.UserID].Username, Lines: v.CodeLines, + User: cvt.From(m[v.UserID], &domain.User{}), } }), nil } diff --git a/backend/internal/model/usecase/model.go b/backend/internal/model/usecase/model.go index 62a3483..e72f9b8 100644 --- a/backend/internal/model/usecase/model.go +++ b/backend/internal/model/usecase/model.go @@ -69,7 +69,7 @@ func (m *ModelUsecase) Check(ctx context.Context, req *domain.CheckModelReq) (*d "MonkeyCode 是一个基于大模型的代码生成器,它可以根据用户的需求生成代码。", "MonkeyCode 是一个基于大模型的代码生成器,它可以根据用户的需求生成代码。", }, - "query": "PandaWiki", + "query": "MonkeyCode", } url = req.APIBase + "/rerank" } diff --git a/backend/internal/openai/handler/v1/v1.go b/backend/internal/openai/handler/v1/v1.go index b1aaefc..46208bf 100644 --- a/backend/internal/openai/handler/v1/v1.go +++ b/backend/internal/openai/handler/v1/v1.go @@ -50,6 +50,7 @@ func NewV1Handler( g := w.Group("/v1", middleware.Auth()) g.GET("/models", web.BaseHandler(h.ModelList)) g.POST("/completion/accept", web.BindHandler(h.AcceptCompletion), active.Active("user")) + g.POST("/report", web.BindHandler(h.Report), active.Active("user")) g.POST("/chat/completions", web.BaseHandler(h.ChatCompletion), active.Active("user")) g.POST("/completions", web.BaseHandler(h.Completions), active.Active("user")) g.POST("/embeddings", web.BaseHandler(h.Embeddings), active.Active("user")) @@ -96,6 +97,25 @@ func (h *V1Handler) AcceptCompletion(c *web.Context, req domain.AcceptCompletion return nil } +// Report 报告 +// +// @Tags OpenAIV1 +// @Summary 报告 +// @Description 报告 +// @ID report +// @Accept json +// @Produce json +// @Param param body domain.ReportReq true "报告请求" +// @Success 200 {object} web.Resp{} +// @Router /v1/report [post] +func (h *V1Handler) Report(c *web.Context, req domain.ReportReq) error { + h.logger.DebugContext(c.Request().Context(), "Report", slog.Any("req", req)) + if err := h.proxyUse.Report(c.Request().Context(), &req); err != nil { + return err + } + return c.Success(nil) +} + // ModelList 模型列表 // // @Tags OpenAIV1 diff --git a/backend/internal/proxy/recorder.go b/backend/internal/proxy/recorder.go index 55e3e6a..8023835 100644 --- a/backend/internal/proxy/recorder.go +++ b/backend/internal/proxy/recorder.go @@ -16,6 +16,7 @@ import ( "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/consts" "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/pkg/diff" "github.com/chaitin/MonkeyCode/backend/pkg/promptparser" ) @@ -66,7 +67,7 @@ func (r *Recorder) handleShadow() { } var ( - taskID, mode, prompt, language string + taskID, mode, prompt, language, tool, code string ) switch r.ctx.Model.ModelType { @@ -76,9 +77,11 @@ func (r *Recorder) handleShadow() { r.logger.WarnContext(r.ctx.ctx, "unmarshal chat completion request failed", "error", err) return } - prompt = r.getPrompt(r.ctx.ctx, &req) + prompt = req.Metadata["prompt"] taskID = req.Metadata["task_id"] mode = req.Metadata["mode"] + tool = req.Metadata["tool"] + code = req.Metadata["code"] case consts.ModelTypeCoder: var req domain.CompletionRequest @@ -108,19 +111,22 @@ func (r *Recorder) handleShadow() { WorkMode: mode, Prompt: prompt, ProgramLanguage: language, - Role: consts.ChatRoleUser, + Role: consts.ChatRoleAssistant, + } + + switch tool { + case "appliedDiff", "editedExistingFile": + lines := diff.ParseConflictsAndCountLines(code) + for _, line := range lines { + rc.CodeLines += int64(line) + } + case "newFileCreated": + rc.CodeLines = int64(strings.Count(code, "\n")) } - var assistantRc *domain.RecordParam ct := r.ctx.RespHeader.Get("Content-Type") if strings.Contains(ct, "stream") { r.handleStream(rc) - if r.ctx.Model.ModelType == consts.ModelTypeLLM { - assistantRc = rc.Clone() - assistantRc.Role = consts.ChatRoleAssistant - rc.Completion = "" - rc.OutputTokens = 0 - } } else { r.handleJson(rc) } @@ -130,15 +136,20 @@ func (r *Recorder) handleShadow() { With("resp_header", formatHeader(r.ctx.RespHeader)). DebugContext(r.ctx.ctx, "handle shadow", "rc", rc) - if err := r.usecase.Record(context.Background(), rc); err != nil { - r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err) - } - - if assistantRc != nil { - if err := r.usecase.Record(context.Background(), assistantRc); err != nil { + // 记录用户的提问 + if r.ctx.Model.ModelType == consts.ModelTypeLLM && prompt != "" { + tmp := rc.Clone() + tmp.Role = consts.ChatRoleUser + tmp.Completion = "" + tmp.OutputTokens = 0 + if err := r.usecase.Record(context.Background(), tmp); err != nil { r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err) } } + + if err := r.usecase.Record(context.Background(), rc); err != nil { + r.logger.WarnContext(r.ctx.ctx, "记录请求失败", "error", err) + } } func (r *Recorder) writeMeta(body []byte) { diff --git a/backend/internal/proxy/repo/proxy.go b/backend/internal/proxy/repo/proxy.go index de51136..f31235b 100644 --- a/backend/internal/proxy/repo/proxy.go +++ b/backend/internal/proxy/repo/proxy.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "strings" "time" "github.com/google/uuid" @@ -16,6 +17,7 @@ import ( "github.com/chaitin/MonkeyCode/backend/db/task" "github.com/chaitin/MonkeyCode/backend/db/taskrecord" "github.com/chaitin/MonkeyCode/backend/domain" + "github.com/chaitin/MonkeyCode/backend/pkg/diff" "github.com/chaitin/MonkeyCode/backend/pkg/entx" ) @@ -113,6 +115,9 @@ func (r *ProxyRepo) Record(ctx context.Context, record *domain.RecordParam) erro if t.InputTokens == 0 && record.InputTokens > 0 { up.SetInputTokens(record.InputTokens) } + if t.CodeLines > 0 { + up.AddCodeLines(record.CodeLines) + } if t.RequestID != record.RequestID { up.SetRequestID(record.RequestID) up.AddInputTokens(record.InputTokens) @@ -144,6 +149,8 @@ func (r *ProxyRepo) Record(ctx context.Context, record *domain.RecordParam) erro SetPrompt(record.Prompt). SetCompletion(record.Completion). SetOutputTokens(record.OutputTokens). + SetCodeLines(record.CodeLines). + SetCode(record.Code). Save(ctx) return err @@ -182,3 +189,88 @@ func (r *ProxyRepo) AcceptCompletion(ctx context.Context, req *domain.AcceptComp SetCompletion(req.Completion).Exec(ctx) }) } + +func (r *ProxyRepo) Report(ctx context.Context, req *domain.ReportReq) error { + return entx.WithTx(ctx, r.db, func(tx *db.Tx) error { + rc, err := tx.Task.Query().Where(task.TaskID(req.ID)).Only(ctx) + if err != nil { + return err + } + + switch req.Action { + case consts.ReportActionAccept: + if err := tx.Task.UpdateOneID(rc.ID). + SetIsAccept(true). + SetCompletion(req.Content). + Exec(ctx); err != nil { + return err + } + + return tx.TaskRecord.Update(). + Where(taskrecord.TaskID(rc.ID)). + SetCompletion(req.Content).Exec(ctx) + + case consts.ReportActionSuggest: + if err := tx.Task.UpdateOneID(rc.ID). + SetIsSuggested(true). + SetCompletion(req.Content). + Exec(ctx); err != nil { + return err + } + + return tx.TaskRecord.Update(). + Where(taskrecord.TaskID(rc.ID)). + SetCompletion(req.Content).Exec(ctx) + + case consts.ReportActionFileWritten: + if err := r.handleFileWritten(ctx, tx, rc, req); err != nil { + return err + } + } + + return nil + }) +} + +func (r *ProxyRepo) handleFileWritten(ctx context.Context, tx *db.Tx, rc *db.Task, req *domain.ReportReq) error { + lineCount := 0 + switch req.Tool { + case "appliedDiff", "editedExistingFile", "insertContent": + if strings.Contains(req.Content, "<<<<<<<") { + lines := diff.ParseConflictsAndCountLines(req.Content) + for _, line := range lines { + lineCount += line + } + } else { + lineCount = strings.Count(req.Content, "\n") + } + case "newFileCreated": + lineCount = strings.Count(req.Content, "\n") + } + + if lineCount > 0 { + if err := tx.Task. + UpdateOneID(rc.ID). + AddCodeLines(int64(lineCount)). + SetIsAccept(true). + Exec(ctx); err != nil { + return err + } + } + + if req.Content != "" { + if _, err := tx.TaskRecord.Create(). + SetTaskID(rc.ID). + SetRole(consts.ChatRoleSystem). + SetPrompt("写入文件"). + SetCompletion(""). + SetCodeLines(int64(lineCount)). + SetCode(req.Content). + SetOutputTokens(0). + Save(ctx); err != nil { + return err + } + } + + return nil +} diff --git a/backend/internal/proxy/usecase/proxy.go b/backend/internal/proxy/usecase/proxy.go index 065e506..65dcc7b 100644 --- a/backend/internal/proxy/usecase/proxy.go +++ b/backend/internal/proxy/usecase/proxy.go @@ -42,3 +42,7 @@ func (p *ProxyUsecase) ValidateApiKey(ctx context.Context, key string) (*domain. func (p *ProxyUsecase) AcceptCompletion(ctx context.Context, req *domain.AcceptCompletionReq) error { return p.repo.AcceptCompletion(ctx, req) } + +func (p *ProxyUsecase) Report(ctx context.Context, req *domain.ReportReq) error { + return p.repo.Report(ctx, req) +} diff --git a/backend/internal/user/handler/v1/user.go b/backend/internal/user/handler/v1/user.go index 804a903..019c894 100644 --- a/backend/internal/user/handler/v1/user.go +++ b/backend/internal/user/handler/v1/user.go @@ -11,6 +11,7 @@ import ( "time" "github.com/GoYoko/web" + "golang.org/x/time/rate" "github.com/chaitin/MonkeyCode/backend/config" "github.com/chaitin/MonkeyCode/backend/consts" @@ -33,9 +34,9 @@ type UserHandler struct { session *session.Session logger *slog.Logger cfg *config.Config - limitCh chan struct{} - vsixCache map[string]*CacheEntry // 缓存处理后的vsix文件 - cacheMu sync.RWMutex // 缓存读写锁 + vsixCache map[string]*CacheEntry + cacheMu sync.RWMutex + limiter *rate.Limiter } func NewUserHandler( @@ -54,8 +55,8 @@ func NewUserHandler( logger: logger, cfg: cfg, euse: euse, - limitCh: make(chan struct{}, cfg.Extension.Limit), vsixCache: make(map[string]*CacheEntry), + limiter: rate.NewLimiter(rate.Every(time.Duration(cfg.Extension.LimitSecond)*time.Second), cfg.Extension.Limit), } w.GET("/api/v1/static/vsix/:version", web.BaseHandler(u.VSIXDownload)) @@ -130,27 +131,22 @@ func (h *UserHandler) cleanExpiredCache() { // @Produce octet-stream // @Router /api/v1/static/vsix [get] func (h *UserHandler) VSIXDownload(c *web.Context) error { - h.limitCh <- struct{}{} - defer func() { - <-h.limitCh - }() + if !h.limiter.Allow() { + return c.String(http.StatusTooManyRequests, "Too Many Requests") + } v, err := h.euse.GetByVersion(c.Request().Context(), c.Param("version")) if err != nil { return err } - // 生成缓存键 cacheKey := h.generateCacheKey(v.Version, h.cfg.BaseUrl) - // 先检查缓存 h.cacheMu.RLock() if entry, exists := h.vsixCache[cacheKey]; exists { - // 检查缓存是否过期(1小时) if time.Since(entry.createdAt) < time.Hour { h.cacheMu.RUnlock() - // 从缓存返回数据 disposition := fmt.Sprintf("attachment; filename=monkeycode-%s.vsix", v.Version) c.Response().Header().Set("Content-Type", "application/octet-stream") c.Response().Header().Set("Content-Disposition", disposition) @@ -162,15 +158,11 @@ func (h *UserHandler) VSIXDownload(c *web.Context) error { } h.cacheMu.RUnlock() - // 缓存未命中或已过期,需要重新生成 - - // 使用buffer来捕获生成的数据 var buf bytes.Buffer if err := vsix.ChangeVsixEndpoint(v.Path, "extension/package.json", h.cfg.BaseUrl, &buf); err != nil { return err } - // 将结果存入缓存 data := buf.Bytes() h.cacheMu.Lock() h.vsixCache[cacheKey] = &CacheEntry{ @@ -179,10 +171,8 @@ func (h *UserHandler) VSIXDownload(c *web.Context) error { } h.cacheMu.Unlock() - // 异步清理过期缓存 go h.cleanExpiredCache() - // 返回数据给客户端 disposition := fmt.Sprintf("attachment; filename=monkeycode-%s.vsix", v.Version) c.Response().Header().Set("Content-Type", "application/octet-stream") c.Response().Header().Set("Content-Disposition", disposition) diff --git a/backend/migration/000009_alter_task_table.down.sql b/backend/migration/000009_alter_task_table.down.sql new file mode 100644 index 0000000..e69de29 diff --git a/backend/migration/000009_alter_task_table.up.sql b/backend/migration/000009_alter_task_table.up.sql new file mode 100644 index 0000000..9ad72e9 --- /dev/null +++ b/backend/migration/000009_alter_task_table.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE tasks ADD column is_suggested boolean default false; +ALTER TABLE task_records ADD column code_lines int default 0; +ALTER TABLE task_records ADD column code text; \ No newline at end of file diff --git a/backend/pkg/diff/diff.go b/backend/pkg/diff/diff.go new file mode 100644 index 0000000..769615a --- /dev/null +++ b/backend/pkg/diff/diff.go @@ -0,0 +1,157 @@ +package diff + +import ( + "strings" +) + +type ConflictBlock struct { + OursContent []string + TheirsContent []string + StartLine int + EndLine int +} + +type ConflictParser struct { + lines []string + currentLine int + conflicts []ConflictBlock +} + +func NewConflictParser(text string) *ConflictParser { + return &ConflictParser{ + lines: strings.Split(text, "\n"), + } +} + +func (p *ConflictParser) ParseConflicts() []ConflictBlock { + for p.currentLine < len(p.lines) { + line := p.lines[p.currentLine] + + if p.isConflictStart(line) { + conflict := p.parseConflictBlock() + if conflict != nil { + p.conflicts = append(p.conflicts, *conflict) + } + } else { + p.currentLine++ + } + } + return p.conflicts +} + +func (p *ConflictParser) isConflictStart(line string) bool { + trimmed := strings.TrimSpace(line) + if len(trimmed) < 7 { + return false + } + + for i := range 7 { + if trimmed[i] != '<' { + return false + } + } + return true +} + +func (p *ConflictParser) isConflictSeparator(line string) bool { + trimmed := strings.TrimSpace(line) + if len(trimmed) != 7 { + return false + } + + for i := range 7 { + if trimmed[i] != '=' { + return false + } + } + return true +} + +func (p *ConflictParser) isConflictEnd(line string) bool { + trimmed := strings.TrimSpace(line) + if len(trimmed) < 7 { + return false + } + + for i := range 7 { + if trimmed[i] != '>' { + return false + } + } + return true +} + +func (p *ConflictParser) parseConflictBlock() *ConflictBlock { + startLine := p.currentLine + p.currentLine++ + + conflict := &ConflictBlock{ + StartLine: startLine, + } + + for p.currentLine < len(p.lines) { + line := p.lines[p.currentLine] + + if p.isConflictSeparator(line) { + p.currentLine++ + break + } + + conflict.OursContent = append(conflict.OursContent, line) + p.currentLine++ + } + + for p.currentLine < len(p.lines) { + line := p.lines[p.currentLine] + + if p.isConflictEnd(line) { + conflict.EndLine = p.currentLine + p.currentLine++ + return conflict + } + + conflict.TheirsContent = append(conflict.TheirsContent, line) + p.currentLine++ + } + + return nil +} + +func CountAddedLines(text string) int { + lines := strings.Split(text, "\n") + addedLines := 0 + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed != "" && !strings.HasPrefix(trimmed, "//") && !strings.HasPrefix(trimmed, "#") && !strings.HasPrefix(trimmed, "/*") { + addedLines++ + } + } + + return addedLines +} + +func (cb *ConflictBlock) CountAddedLinesInConflict() int { + theirsText := strings.Join(cb.TheirsContent, "\n") + return CountAddedLines(theirsText) +} + +func (cb *ConflictBlock) CountNetAddedLines() int { + oursText := strings.Join(cb.OursContent, "\n") + theirsText := strings.Join(cb.TheirsContent, "\n") + oursLines := CountAddedLines(oursText) + theirsLines := CountAddedLines(theirsText) + return theirsLines - oursLines +} + +func ParseConflictsAndCountLines(text string) []int { + parser := NewConflictParser(text) + conflicts := parser.ParseConflicts() + + var addedLines []int + for _, conflict := range conflicts { + addedLines = append(addedLines, conflict.CountAddedLinesInConflict()) + } + + return addedLines +} diff --git a/backend/pkg/diff/diff_test.go b/backend/pkg/diff/diff_test.go new file mode 100644 index 0000000..d4ef4a7 --- /dev/null +++ b/backend/pkg/diff/diff_test.go @@ -0,0 +1,25 @@ +package diff + +import ( + "fmt" + "testing" +) + +func TestParseConflictsAndCountLines(t *testing.T) { + conflictText := `<<<<<<< HEAD +old line 1 +======= +new line 1 +new line 2 +>>>>>>> branch1 +normal line +<<<<<<< HEAD +old line 2 +old line 3 +======= +new line 3 +>>>>>>> branch2` + + addedLines := ParseConflictsAndCountLines(conflictText) + fmt.Println(addedLines) +} diff --git a/ui/src/api/OpenAiv1.ts b/ui/src/api/OpenAiv1.ts index 3a61f16..52ad1e6 100644 --- a/ui/src/api/OpenAiv1.ts +++ b/ui/src/api/OpenAiv1.ts @@ -14,6 +14,7 @@ import request, { ContentType, RequestParams } from "./httpClient"; import { DomainAcceptCompletionReq, DomainModelListResp, + DomainReportReq, WebResp, } from "./types"; @@ -122,3 +123,26 @@ export const getModelList = (params: RequestParams = {}) => format: "json", ...params, }); + +/** + * @description 报告 + * + * @tags OpenAIV1 + * @name PostReport + * @summary 报告 + * @request POST:/v1/report + * @response `200` `WebResp` OK + */ + +export const postReport = ( + param: DomainReportReq, + params: RequestParams = {}, +) => + request({ + path: `/v1/report`, + method: "POST", + body: param, + type: ContentType.Json, + format: "json", + ...params, + }); diff --git a/ui/src/api/types.ts b/ui/src/api/types.ts index 0a4e9c4..d4e769e 100644 --- a/ui/src/api/types.ts +++ b/ui/src/api/types.ts @@ -22,6 +22,12 @@ export enum ConstsUserPlatform { UserPlatformCustom = "custom", } +export enum ConstsReportAction { + ReportActionAccept = "accept", + ReportActionSuggest = "suggest", + ReportActionFileWritten = "file_written", +} + export enum ConstsModelType { ModelTypeLLM = "llm", ModelTypeCoder = "coder", @@ -51,6 +57,7 @@ export enum ConstsModelProvider { export enum ConstsChatRole { ChatRoleUser = "user", ChatRoleAssistant = "assistant", + ChatRoleSystem = "system", } export enum ConstsAdminStatus { @@ -480,6 +487,16 @@ export interface DomainRegisterReq { username: string; } +export interface DomainReportReq { + action?: ConstsReportAction; + /** 内容 */ + content?: string; + /** task_id or resp_id */ + id?: string; + /** 工具 */ + tool?: string; +} + export interface DomainSetting { /** 创建时间 */ created_at?: number; @@ -639,6 +656,8 @@ export interface DomainUser { export interface DomainUserCodeRank { /** 代码行数 */ lines?: number; + /** 用户信息 */ + user?: DomainUser; /** 用户名 */ username?: string; } @@ -767,6 +786,8 @@ export interface GetListChatRecordParams { page?: number; /** 每页多少条记录 */ size?: number; + /** 工作模式 */ + work_mode?: string; } export interface GetCompletionInfoParams { @@ -787,6 +808,8 @@ export interface GetListCompletionRecordParams { page?: number; /** 每页多少条记录 */ size?: number; + /** 工作模式 */ + work_mode?: string; } export interface GetCategoryStatDashboardParams { diff --git a/ui/src/pages/chat/index.tsx b/ui/src/pages/chat/index.tsx index d1759f1..d54b46f 100644 --- a/ui/src/pages/chat/index.tsx +++ b/ui/src/pages/chat/index.tsx @@ -4,7 +4,7 @@ import { getListChatRecord } from '@/api/Billing'; import dayjs from 'dayjs'; import Card from '@/components/card'; -import { Box } from '@mui/material'; +import { Autocomplete, Box, FormControl, InputLabel, MenuItem, Select, Stack, TextField } from '@mui/material'; import StyledLabel from '@/components/label'; import ChatDetailModal from './chatDetailModal'; @@ -12,6 +12,9 @@ import { ColumnsType } from '@c-x/ui/dist/Table'; import { DomainChatRecord, DomainUser } from '@/api/types'; import { addCommasToNumber } from '@/utils'; import User from '@/components/user'; +import { useRequest } from 'ahooks'; +import { getListUser } from '@/api/User'; +import { set } from 'react-hook-form'; const Chat = () => { const [page, setPage] = useState(1); @@ -22,11 +25,25 @@ const Chat = () => { const [chatDetailModal, setChatDetailModal] = useState< DomainChatRecord | undefined >(); + const [filterUser, setFilterUser] = useState(''); + const [filterMode, setfilterMode] = useState< + 'code' | 'architect' | 'ask' | 'debug' | 'orchestrator' + >(); + + const { data: userOptions = { users: [] } } = useRequest(() => + getListUser({ + page: 1, + size: 9999, + }) + ); + const fetchData = async () => { setLoading(true); const res = await getListChatRecord({ page: page, size: size, + work_mode: filterMode, + author: filterUser, }); setLoading(false); setTotal(res?.total_count || 0); @@ -34,9 +51,10 @@ const Chat = () => { }; useEffect(() => { + setPage(1); fetchData(); // eslint-disable-next-line - }, [page, size]); + }, [page, size, filterMode, filterUser]); const columns: ColumnsType = [ { @@ -137,6 +155,43 @@ const Chat = () => { ]; return ( + + option.username || ''} + value={ + userOptions.users?.find((item) => item.username === filterUser) || + null + } + onChange={(_, newValue) => + setFilterUser(newValue ? newValue.username! : '') + } + isOptionEqualToValue={(option, value) => + option.username === value.username + } + renderInput={(params) => } + clearOnEscape + /> + + 工作模式 + + + - {/* */} - - {item.username} + + + {item.username} - {item.lines} + {item.lines} 行 ))} diff --git a/ui/src/pages/dashboard/components/statisticCard.tsx b/ui/src/pages/dashboard/components/statisticCard.tsx index 5c64809..420fab5 100644 --- a/ui/src/pages/dashboard/components/statisticCard.tsx +++ b/ui/src/pages/dashboard/components/statisticCard.tsx @@ -1,5 +1,5 @@ -import React, { useState } from 'react'; -import { styled, Stack, Box, Button } from '@mui/material'; +import { useState } from 'react'; +import { styled, Stack, Box } from '@mui/material'; import { Empty } from '@c-x/ui'; import dayjs from 'dayjs'; import { useNavigate } from 'react-router-dom'; @@ -12,6 +12,7 @@ import { DomainUserCodeRank, DomainUserEvent, } from '@/api/types'; +import Avatar from '@/components/avatar'; const StyledCardLabel = styled('div')(({ theme }) => ({ fontSize: '14px', @@ -122,26 +123,19 @@ export const ContributionCard = ({ gap={1.5} sx={{ flex: 1, - minWidth: 0, - // cursor: 'pointer', - // '&:hover': { - // '.active-user-name': { - // color: 'primary.main', - // }, - // }, + minWidth: 0 }} - // onClick={() => { - // navigate(`/`) - // // window.open(`/discussion/user/${item.id}`); - // }} > - {/* */} - - {item.username} + + {item.username} - {item.lines} + {item.lines} 行 ))} diff --git a/ui/src/pages/user/chat/index.tsx b/ui/src/pages/user/chat/index.tsx index d1759f1..2df2649 100644 --- a/ui/src/pages/user/chat/index.tsx +++ b/ui/src/pages/user/chat/index.tsx @@ -34,6 +34,7 @@ const Chat = () => { }; useEffect(() => { + setPage(1); // 筛选变化时重置页码 fetchData(); // eslint-disable-next-line }, [page, size]); pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy