实现category service基础接口

This commit is contained in:
2025-05-31 16:27:01 +08:00
parent faa6a35475
commit e5446bf836
33 changed files with 1420 additions and 67 deletions

View File

@ -0,0 +1,167 @@
package logic
import (
"context"
"godemo/category/category"
"godemo/category/internal/model"
"godemo/category/internal/svc"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type GetSystemCategoriesLogic struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
}
func NewGetSystemCategoriesLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetSystemCategoriesLogic {
return &GetSystemCategoriesLogic{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
}
}
// 根据系统ID获取分类
func (l *GetSystemCategoriesLogic) GetSystemCategories(in *category.GetSystemCategoriesRequest) (*category.GetSystemCategoriesResponse, error) {
// 1. 参数验证
if in.SystemId == "" {
return nil, status.Error(codes.InvalidArgument, "system_id is required")
}
// 2. 设置默认分页值
page, pageSize := normalizePagination(in.Page, in.PageSize)
// 3. 根据是否包含后代选择查询方式
var categories []*model.Categories
var total int64
var err error
if in.IncludeDescendants {
// 递归查询所有后代
categories, err = l.getDescendantsRecursively(in.SystemId, in.ParentId)
if err != nil {
return nil, status.Error(codes.Internal, "failed to get descendant categories")
}
total = int64(len(categories))
} else {
// 直接查询子分类
categories, total, err = l.getDirectChildren(in.SystemId, in.ParentId, page, pageSize)
if err != nil {
return nil, status.Error(codes.Internal, "failed to get categories")
}
}
// 4. 转换模型到protobuf
pbCategories := make([]*category.CategoryInfo, 0, len(categories))
for _, cat := range categories {
pbCategories = append(pbCategories, convertToPbCategory(cat))
}
// 5. 构建响应
return &category.GetSystemCategoriesResponse{
Categories: pbCategories,
Total: total,
CurrentPage: int32(page),
PageSize: int32(pageSize),
}, nil
}
// 获取直接子分类(非递归)
func (l *GetSystemCategoriesLogic) getDirectChildren(systemID, parentID string, page, pageSize int) ([]*model.Categories, int64, error) {
// 构建基础查询
query := l.svcCtx.CategoryModel.RowBuilder().
Where("system_id = $1", systemID)
// 处理父分类ID
if parentID == "" {
query = query.Where("(parent_id IS NULL OR parent_id = '')")
} else {
query = query.Where("parent_id = $2", parentID)
}
// 获取总数
total, err := l.svcCtx.CategoryModel.FindCount(l.ctx, query.RemoveColumns())
if err != nil {
return nil, 0, err
}
// 分页查询
query = query.OrderBy("created_at DESC").
Offset(uint64((page - 1) * pageSize)).
Limit(uint64(pageSize))
categories, err := l.svcCtx.CategoryModel.FindAll(l.ctx, query)
if err != nil {
return nil, 0, err
}
return categories, total, nil
}
// 递归获取所有后代分类
func (l *GetSystemCategoriesLogic) getDescendantsRecursively(systemID, parentID string) ([]*model.Categories, error) {
// 使用递归CTE查询
query := `
WITH RECURSIVE category_tree AS (
SELECT id, system_id, name, alias, parent_id, description, created_at, updated_at
FROM categories
WHERE system_id = $1
AND (
CASE
WHEN $2 = '' THEN parent_id IS NULL
ELSE parent_id = $2
END
)
UNION ALL
SELECT c.id, c.system_id, c.name, c.alias, c.parent_id, c.description, c.created_at, c.updated_at
FROM categories c
INNER JOIN category_tree ct ON c.parent_id = ct.id
)
SELECT * FROM category_tree
`
var categories []*model.Categories
err := l.svcCtx.SqlConn.QueryRowsCtx(l.ctx, &categories, query, systemID, parentID)
if err != nil && err != sqlx.ErrNotFound {
return nil, err
}
return categories, nil
}
// 转换模型到protobuf
func convertToPbCategory(cat *model.Categories) *category.CategoryInfo {
return &category.CategoryInfo{
Id: cat.Id,
SystemId: cat.SystemId,
Name: cat.Name,
Alias: cat.Alias.String,
ParentId: cat.ParentId.String,
Description: cat.Description.String,
CreatedAt: cat.CreatedAt.Unix(),
UpdatedAt: cat.UpdatedAt.Time.Unix(),
}
}
// 规范化分页参数
func normalizePagination(page, pageSize int32) (int, int) {
p := int(page)
ps := int(pageSize)
if p <= 0 {
p = 1
}
if ps <= 0 {
ps = 20
} else if ps > 100 {
ps = 100
}
return p, ps
}

View File

@ -3,6 +3,7 @@ package model
import (
"context"
"github.com/Masterminds/squirrel"
"github.com/zeromicro/go-zero/core/stores/sqlx"
)
@ -17,6 +18,11 @@ type (
Transact(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error
FindBySystemParentName(ctx context.Context, systemID string, parentID string, name string) (*Categories, error)
FindBySystemParentAlias(ctx context.Context, systemID string, parentID string, alias string) (*Categories, error)
// 新增方法
RowBuilder() squirrel.SelectBuilder
FindCount(ctx context.Context, builder squirrel.SelectBuilder) (int64, error)
FindAll(ctx context.Context, builder squirrel.SelectBuilder) ([]*Categories, error)
}
customCategoriesModel struct {
@ -72,3 +78,81 @@ func (m *customCategoriesModel) FindBySystemParentAlias(ctx context.Context, sys
}
return &resp, nil
}
// === 新增方法实现 ===
// RowBuilder 创建一个基本的SELECT查询构建器
func (m *customCategoriesModel) RowBuilder() squirrel.SelectBuilder {
return squirrel.Select(categoriesRows).From(m.table)
}
// FindCount 执行COUNT查询
func (m *customCategoriesModel) FindCount(ctx context.Context, builder squirrel.SelectBuilder) (int64, error) {
// 将SELECT转换为COUNT
builder = builder.Columns("COUNT(1) AS count")
query, args, err := builder.ToSql()
if err != nil {
return 0, err
}
var count int64
err = m.conn.QueryRowCtx(ctx, &count, query, args...)
if err != nil {
return 0, err
}
return count, nil
}
// FindAll 执行查询并返回所有结果
func (m *customCategoriesModel) FindAll(ctx context.Context, builder squirrel.SelectBuilder) ([]*Categories, error) {
query, args, err := builder.ToSql()
if err != nil {
return nil, err
}
var categories []*Categories
err = m.conn.QueryRowsCtx(ctx, &categories, query, args...)
if err != nil && err != sqlx.ErrNotFound {
return nil, err
}
return categories, nil
}
// === 辅助函数 ===
// 添加分页支持
func (m *customCategoriesModel) FindAllWithPagination(
ctx context.Context,
builder squirrel.SelectBuilder,
orderBy string,
page int,
pageSize int,
) ([]*Categories, int64, error) {
// 获取总数
total, err := m.FindCount(ctx, builder)
if err != nil {
return nil, 0, err
}
// 添加分页
if page > 0 && pageSize > 0 {
offset := (page - 1) * pageSize
builder = builder.Offset(uint64(offset)).Limit(uint64(pageSize))
}
// 添加排序
if orderBy != "" {
builder = builder.OrderBy(orderBy)
}
// 执行查询
categories, err := m.FindAll(ctx, builder)
if err != nil {
return nil, 0, err
}
return categories, total, nil
}

View File

@ -92,3 +92,9 @@ func (s *CategoryServer) CheckAlias(ctx context.Context, in *category.CheckAlias
l := logic.NewCheckAliasLogic(ctx, s.svcCtx)
return l.CheckAlias(in)
}
// 根据系统ID获取分类
func (s *CategoryServer) GetSystemCategories(ctx context.Context, in *category.GetSystemCategoriesRequest) (*category.GetSystemCategoriesResponse, error) {
l := logic.NewGetSystemCategoriesLogic(ctx, s.svcCtx)
return l.GetSystemCategories(in)
}

View File

@ -11,6 +11,7 @@ import (
type ServiceContext struct {
Config config.Config
CategoryModel model.CategoriesModel
SqlConn sqlx.SqlConn // 添加这个字段
}
func NewServiceContext(c config.Config) *ServiceContext {
@ -18,5 +19,6 @@ func NewServiceContext(c config.Config) *ServiceContext {
return &ServiceContext{
Config: c,
CategoryModel: model.NewCategoriesModel(conn),
SqlConn: conn, // 赋值
}
}