实现category service基础接口
This commit is contained in:
167
category/internal/logic/getsystemcategorieslogic.go
Normal file
167
category/internal/logic/getsystemcategorieslogic.go
Normal file
@ -0,0 +1,167 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"godemo/category/category"
|
||||
"godemo/category/internal/model"
|
||||
"godemo/category/internal/svc"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type GetSystemCategoriesLogic struct {
|
||||
ctx context.Context
|
||||
svcCtx *svc.ServiceContext
|
||||
logx.Logger
|
||||
}
|
||||
|
||||
func NewGetSystemCategoriesLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetSystemCategoriesLogic {
|
||||
return &GetSystemCategoriesLogic{
|
||||
ctx: ctx,
|
||||
svcCtx: svcCtx,
|
||||
Logger: logx.WithContext(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
// 根据系统ID获取分类
|
||||
func (l *GetSystemCategoriesLogic) GetSystemCategories(in *category.GetSystemCategoriesRequest) (*category.GetSystemCategoriesResponse, error) {
|
||||
// 1. 参数验证
|
||||
if in.SystemId == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "system_id is required")
|
||||
}
|
||||
|
||||
// 2. 设置默认分页值
|
||||
page, pageSize := normalizePagination(in.Page, in.PageSize)
|
||||
|
||||
// 3. 根据是否包含后代选择查询方式
|
||||
var categories []*model.Categories
|
||||
var total int64
|
||||
var err error
|
||||
|
||||
if in.IncludeDescendants {
|
||||
// 递归查询所有后代
|
||||
categories, err = l.getDescendantsRecursively(in.SystemId, in.ParentId)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to get descendant categories")
|
||||
}
|
||||
total = int64(len(categories))
|
||||
} else {
|
||||
// 直接查询子分类
|
||||
categories, total, err = l.getDirectChildren(in.SystemId, in.ParentId, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to get categories")
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 转换模型到protobuf
|
||||
pbCategories := make([]*category.CategoryInfo, 0, len(categories))
|
||||
for _, cat := range categories {
|
||||
pbCategories = append(pbCategories, convertToPbCategory(cat))
|
||||
}
|
||||
|
||||
// 5. 构建响应
|
||||
return &category.GetSystemCategoriesResponse{
|
||||
Categories: pbCategories,
|
||||
Total: total,
|
||||
CurrentPage: int32(page),
|
||||
PageSize: int32(pageSize),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 获取直接子分类(非递归)
|
||||
func (l *GetSystemCategoriesLogic) getDirectChildren(systemID, parentID string, page, pageSize int) ([]*model.Categories, int64, error) {
|
||||
// 构建基础查询
|
||||
query := l.svcCtx.CategoryModel.RowBuilder().
|
||||
Where("system_id = $1", systemID)
|
||||
|
||||
// 处理父分类ID
|
||||
if parentID == "" {
|
||||
query = query.Where("(parent_id IS NULL OR parent_id = '')")
|
||||
} else {
|
||||
query = query.Where("parent_id = $2", parentID)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
total, err := l.svcCtx.CategoryModel.FindCount(l.ctx, query.RemoveColumns())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
query = query.OrderBy("created_at DESC").
|
||||
Offset(uint64((page - 1) * pageSize)).
|
||||
Limit(uint64(pageSize))
|
||||
|
||||
categories, err := l.svcCtx.CategoryModel.FindAll(l.ctx, query)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return categories, total, nil
|
||||
}
|
||||
|
||||
// 递归获取所有后代分类
|
||||
func (l *GetSystemCategoriesLogic) getDescendantsRecursively(systemID, parentID string) ([]*model.Categories, error) {
|
||||
// 使用递归CTE查询
|
||||
query := `
|
||||
WITH RECURSIVE category_tree AS (
|
||||
SELECT id, system_id, name, alias, parent_id, description, created_at, updated_at
|
||||
FROM categories
|
||||
WHERE system_id = $1
|
||||
AND (
|
||||
CASE
|
||||
WHEN $2 = '' THEN parent_id IS NULL
|
||||
ELSE parent_id = $2
|
||||
END
|
||||
)
|
||||
UNION ALL
|
||||
SELECT c.id, c.system_id, c.name, c.alias, c.parent_id, c.description, c.created_at, c.updated_at
|
||||
FROM categories c
|
||||
INNER JOIN category_tree ct ON c.parent_id = ct.id
|
||||
)
|
||||
SELECT * FROM category_tree
|
||||
`
|
||||
|
||||
var categories []*model.Categories
|
||||
err := l.svcCtx.SqlConn.QueryRowsCtx(l.ctx, &categories, query, systemID, parentID)
|
||||
if err != nil && err != sqlx.ErrNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return categories, nil
|
||||
}
|
||||
|
||||
// 转换模型到protobuf
|
||||
func convertToPbCategory(cat *model.Categories) *category.CategoryInfo {
|
||||
return &category.CategoryInfo{
|
||||
Id: cat.Id,
|
||||
SystemId: cat.SystemId,
|
||||
Name: cat.Name,
|
||||
Alias: cat.Alias.String,
|
||||
ParentId: cat.ParentId.String,
|
||||
Description: cat.Description.String,
|
||||
CreatedAt: cat.CreatedAt.Unix(),
|
||||
UpdatedAt: cat.UpdatedAt.Time.Unix(),
|
||||
}
|
||||
}
|
||||
|
||||
// 规范化分页参数
|
||||
func normalizePagination(page, pageSize int32) (int, int) {
|
||||
p := int(page)
|
||||
ps := int(pageSize)
|
||||
|
||||
if p <= 0 {
|
||||
p = 1
|
||||
}
|
||||
if ps <= 0 {
|
||||
ps = 20
|
||||
} else if ps > 100 {
|
||||
ps = 100
|
||||
}
|
||||
|
||||
return p, ps
|
||||
}
|
||||
@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
)
|
||||
|
||||
@ -17,6 +18,11 @@ type (
|
||||
Transact(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error
|
||||
FindBySystemParentName(ctx context.Context, systemID string, parentID string, name string) (*Categories, error)
|
||||
FindBySystemParentAlias(ctx context.Context, systemID string, parentID string, alias string) (*Categories, error)
|
||||
|
||||
// 新增方法
|
||||
RowBuilder() squirrel.SelectBuilder
|
||||
FindCount(ctx context.Context, builder squirrel.SelectBuilder) (int64, error)
|
||||
FindAll(ctx context.Context, builder squirrel.SelectBuilder) ([]*Categories, error)
|
||||
}
|
||||
|
||||
customCategoriesModel struct {
|
||||
@ -72,3 +78,81 @@ func (m *customCategoriesModel) FindBySystemParentAlias(ctx context.Context, sys
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// === 新增方法实现 ===
|
||||
|
||||
// RowBuilder 创建一个基本的SELECT查询构建器
|
||||
func (m *customCategoriesModel) RowBuilder() squirrel.SelectBuilder {
|
||||
return squirrel.Select(categoriesRows).From(m.table)
|
||||
}
|
||||
|
||||
// FindCount 执行COUNT查询
|
||||
func (m *customCategoriesModel) FindCount(ctx context.Context, builder squirrel.SelectBuilder) (int64, error) {
|
||||
// 将SELECT转换为COUNT
|
||||
builder = builder.Columns("COUNT(1) AS count")
|
||||
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var count int64
|
||||
err = m.conn.QueryRowCtx(ctx, &count, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// FindAll 执行查询并返回所有结果
|
||||
func (m *customCategoriesModel) FindAll(ctx context.Context, builder squirrel.SelectBuilder) ([]*Categories, error) {
|
||||
query, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var categories []*Categories
|
||||
err = m.conn.QueryRowsCtx(ctx, &categories, query, args...)
|
||||
if err != nil && err != sqlx.ErrNotFound {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return categories, nil
|
||||
}
|
||||
|
||||
// === 辅助函数 ===
|
||||
|
||||
// 添加分页支持
|
||||
func (m *customCategoriesModel) FindAllWithPagination(
|
||||
ctx context.Context,
|
||||
builder squirrel.SelectBuilder,
|
||||
orderBy string,
|
||||
page int,
|
||||
pageSize int,
|
||||
) ([]*Categories, int64, error) {
|
||||
// 获取总数
|
||||
total, err := m.FindCount(ctx, builder)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 添加分页
|
||||
if page > 0 && pageSize > 0 {
|
||||
offset := (page - 1) * pageSize
|
||||
builder = builder.Offset(uint64(offset)).Limit(uint64(pageSize))
|
||||
}
|
||||
|
||||
// 添加排序
|
||||
if orderBy != "" {
|
||||
builder = builder.OrderBy(orderBy)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
categories, err := m.FindAll(ctx, builder)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return categories, total, nil
|
||||
}
|
||||
|
||||
@ -92,3 +92,9 @@ func (s *CategoryServer) CheckAlias(ctx context.Context, in *category.CheckAlias
|
||||
l := logic.NewCheckAliasLogic(ctx, s.svcCtx)
|
||||
return l.CheckAlias(in)
|
||||
}
|
||||
|
||||
// 根据系统ID获取分类
|
||||
func (s *CategoryServer) GetSystemCategories(ctx context.Context, in *category.GetSystemCategoriesRequest) (*category.GetSystemCategoriesResponse, error) {
|
||||
l := logic.NewGetSystemCategoriesLogic(ctx, s.svcCtx)
|
||||
return l.GetSystemCategories(in)
|
||||
}
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
type ServiceContext struct {
|
||||
Config config.Config
|
||||
CategoryModel model.CategoriesModel
|
||||
SqlConn sqlx.SqlConn // 添加这个字段
|
||||
}
|
||||
|
||||
func NewServiceContext(c config.Config) *ServiceContext {
|
||||
@ -18,5 +19,6 @@ func NewServiceContext(c config.Config) *ServiceContext {
|
||||
return &ServiceContext{
|
||||
Config: c,
|
||||
CategoryModel: model.NewCategoriesModel(conn),
|
||||
SqlConn: conn, // 赋值
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user