深色模式
把 goctl 改为 ginctl
goctl 是什么
goctl是go-zero框架提供的命令行工具,用于一键生成模板代码。
go-zero更适合微服务,而我只想用gin开发单体服务,于是对goctl的代码做了一些修改,让它能够生成基于gin的单体服务项目模板代码。
goctl 代码分析
入口文件
goctl.go
cmd/root.go
api/cmd.go
api/apigen/:生成api模板文件,很少使用
api/gogen/:生成go项目代码,是主要修改的地方
核心代码
位于api/gogen/gen.go:
go
// Generate the project piece by piece, in this order. Each generator's error is
// handed to logx.Must (go-zero helper; presumably logs and exits on a non-nil
// error — confirm), so generation stops at the first failure.
logx.Must(genEtc(dir, cfg, api))                      // etc/config-{dev,test,pro}.yaml
logx.Must(genConfig(dir, cfg, api))                   // config package (Config struct)
logx.Must(genMain(dir, rootPkg, cfg, api))            // main.go
logx.Must(genServiceContext(dir, rootPkg, cfg, api))  // svc package (ServiceContext)
logx.Must(genTypes(dir, cfg, api))                    // request/response types
logx.Must(genRoutes(dir, rootPkg, cfg, api))          // gin route registration
logx.Must(genResponse(dir, rootPkg, cfg, api))        // unified response helpers
logx.Must(genHandlers(dir, rootPkg, cfg, api))        // one handler per route
logx.Must(genLogic(dir, rootPkg, cfg, api))           // one logic file per route
logx.Must(genMiddleware(dir, rootPkg, cfg, api))      // middleware skeletons
这些调用依次生成各类文件及其相关代码,任意一步出错都会中止生成。
命令的名字
命令的名字其实就是编译产物的文件名。原项目中是goctl,它由go.mod文件中module路径的最后一段决定。
要修改命令的名字,只需修改go.mod文件(同时要把项目中所有以github.com/zeromicro/go-zero/tools/goctl开头的导入路径替换为ginctl,例如后文代码中的"ginctl/api/spec"):
go
module github.com/zeromicro/go-zero/tools/goctl
go
module ginctl
这样,命令的名字就从goctl变成了ginctl。
关于覆盖生成
覆盖生成的原理是:在生成文件之前,先删除旧的。
在项目中,有些地方走了覆盖逻辑,比如:
go
typeFilename = typeFilename + ".go"
filename := path.Join(dir, typesDir, typeFilename)
// Overwrite generation: delete the previous file first so it is generated
// again from scratch. The error is ignored because the file may not exist yet.
os.Remove(filename)
开始修改
genEtc
原来只生成1个etc文件,这里改为生成3个文件(config-dev.yaml、config-test.yaml、config-pro.yaml),分别对应不同的环境。
代码如下:
text
# Generated by ginctl: one file per environment (dev/test/pro).
# The nested sections map onto the nested structs of config.Config;
# the indentation is significant — flattening it repeats Host/Port/Password
# in one mapping, which the YAML decoder rejects.
Host: {{.host}}
Port: "{{.port}}"
Jwt:
  Secret: ""
  ExpireHours: 24
MySQL:
  User: ""
  Password: ""
  Host: ""
  Port: ""
  Database: ""
Redis:
  Host: ""
  Port: ""
  Password: ""
go
package gogen
import (
_ "embed"
"fmt"
"strconv"
"ginctl/api/spec"
"ginctl/config"
)
const (
defaultPort = 8888
etcDir = "etc"
)
//go:embed etc.tpl
var etcTemplate string
func genEtc(dir string, cfg *config.Config, api *spec.ApiSpec) error {
if err := genEtcOnce(dir, "dev", cfg, api); err != nil {
return err
}
if err := genEtcOnce(dir, "test", cfg, api); err != nil {
return err
}
if err := genEtcOnce(dir, "pro", cfg, api); err != nil {
return err
}
return nil
}
// genEtcOnce renders etc/config-<mode>.yaml from the etc template. The file
// listens on all interfaces ("0.0.0.0") and on defaultPort. cfg and api are
// currently unused.
func genEtcOnce(dir, mode string, cfg *config.Config, api *spec.ApiSpec) error {
	data := map[string]string{
		"host": "0.0.0.0",
		"port": strconv.Itoa(defaultPort),
	}
	gen := fileGenConfig{
		dir:             dir,
		subdir:          etcDir,
		filename:        fmt.Sprintf("config-%s.yaml", mode),
		templateName:    "etcTemplate",
		category:        category,
		templateFile:    etcTemplateFile,
		builtinTemplate: etcTemplate,
		data:            data,
	}
	return genFile(gen)
}
genConfig
config与etc对应,也比较简单。
代码如下:
text
package config
import "fmt"
// Config mirrors etc/config-<mode>.yaml; every field is bound to its YAML key
// through the yaml struct tag.
type Config struct {
	// Port is used by the generated main.go as the HTTP listen port;
	// Host is read from the file but not used by main.go.
	Host string `yaml:"Host"`
	Port string `yaml:"Port"`
	// Jwt holds the token secret and ExpireHours (presumably the token lifetime).
	Jwt struct {
		Secret      string `yaml:"Secret"`
		ExpireHours int64  `yaml:"ExpireHours"`
	} `yaml:"Jwt"`
	// MySQL holds the connection parts assembled by MySQLDsn.
	MySQL struct {
		User     string `yaml:"User"`
		Password string `yaml:"Password"`
		Host     string `yaml:"Host"`
		Port     string `yaml:"Port"`
		Database string `yaml:"Database"`
	} `yaml:"MySQL"`
	// Redis holds the Redis connection settings.
	Redis struct {
		Host     string `yaml:"Host"`
		Port     string `yaml:"Port"`
		Password string `yaml:"Password"`
	} `yaml:"Redis"`
}

// MySQLDsn returns a DSN of the form user:password@tcp(host:port)/database,
// requesting the utf8mb4 charset, parseTime=True and the local time zone.
func (c *Config) MySQLDsn() string {
	m := c.MySQL
	return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
		m.User, m.Password, m.Host, m.Port, m.Database)
}
go
package gogen
import (
_ "embed"
"ginctl/api/spec"
"ginctl/config"
"ginctl/util/format"
)
const (
configFile = "config"
)
//go:embed config.tpl
var configTemplate string
func genConfig(dir string, cfg *config.Config, api *spec.ApiSpec) error {
filename, err := format.FileNamingFormat(cfg.NamingFormat, configFile)
if err != nil {
return err
}
return genFile(fileGenConfig{
dir: dir,
subdir: configDir,
filename: filename + ".go",
templateName: "configTemplate",
category: category,
templateFile: configTemplateFile,
builtinTemplate: configTemplate,
data: map[string]string{},
})
}
genMain
main文件的逻辑比较简单。
text
package main
import (
"fmt"
"net/http"
"os"
"time"
"github.com/gin-gonic/gin"
"gopkg.in/yaml.v3"
{{.importPackages}}
)
// envMode selects which etc/config-<mode>.yaml is loaded. It defaults to "dev"
// and is overridden by the XEnvMode environment variable (see readEnv).
var envMode = "dev"
// main resolves the environment mode, loads the matching config file and runs
// the HTTP server.
func main() {
	readEnv()
	startServer(readConfig())
}
// readEnv overrides envMode with the XEnvMode environment variable when that
// variable is set to a non-empty value.
func readEnv() {
	if mode := os.Getenv("XEnvMode"); mode != "" {
		envMode = mode
	}
}
// readConfig reads etc/config-<envMode>.yaml and decodes it into a
// config.Config. It panics when the file cannot be read or decoded, because
// the server cannot start without its configuration.
func readConfig() config.Config {
	var conf config.Config
	data, err := os.ReadFile(fmt.Sprintf("etc/config-%s.yaml", envMode))
	if err == nil {
		err = yaml.Unmarshal(data, &conf)
	}
	if err != nil {
		panic(err)
	}
	return conf
}
// startServer builds the gin engine (request logger, panic recovery and, in
// dev mode, CORS), registers the generated routes and serves HTTP on
// conf.Port with 60s read/write timeouts and a 1 MiB header limit.
// It panics if the server fails to start or stops with an error, instead of
// silently exiting.
func startServer(conf config.Config) {
	engine := gin.New()
	engine.Use(gin.Logger())
	engine.Use(gin.Recovery())
	// CORS is only enabled in dev mode.
	corsInDevMode(engine)

	s := &http.Server{
		Addr:           fmt.Sprintf(":%s", conf.Port),
		Handler:        engine,
		ReadTimeout:    time.Second * 60,
		WriteTimeout:   time.Second * 60,
		MaxHeaderBytes: 1 << 20,
	}

	// Build the shared service context, then register every generated route.
	svcCtx := svc.NewServiceContext(conf)
	handler.RegisterHandlers(engine, svcCtx)

	fmt.Println("[gin] ==> ready to work!")
	// ListenAndServe always returns a non-nil error. http.ErrServerClosed only
	// signals a deliberate Shutdown/Close; anything else (e.g. the port is
	// already in use) must not be swallowed.
	if err := s.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		panic(err)
	}
}
// corsInDevMode installs a permissive CORS middleware when envMode is "dev",
// so a front-end served from another origin can call the API during
// development. Requests carrying an Origin header get the CORS response
// headers. Preflight (OPTIONS) requests are answered with 204 and never reach
// the route handlers.
func corsInDevMode(engine *gin.Engine) {
	if envMode != "dev" {
		return
	}
	engine.Use(func(c *gin.Context) {
		if origin := c.Request.Header.Get("Origin"); origin != "" {
			// Browsers reject "Access-Control-Allow-Origin: *" combined with
			// "Access-Control-Allow-Credentials: true", so echo the caller's
			// origin instead (replace it with a fixed domain if needed).
			c.Header("Access-Control-Allow-Origin", origin)
			c.Header("Vary", "Origin")
			c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, PATCH, DELETE")
			c.Header("Access-Control-Allow-Headers",
				"Origin, X-Requested-With, Content-Type, Accept, Authorization")
			c.Header("Access-Control-Expose-Headers",
				"Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Cache-Control, Content-Language, Content-Type")
			c.Header("Access-Control-Allow-Credentials", "true")
		}
		if c.Request.Method == http.MethodOptions {
			// Preflight handled: stop the chain here.
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	})
}
go
package gogen
import (
_ "embed"
"fmt"
"strings"
"ginctl/api/spec"
"ginctl/config"
"ginctl/util/pathx"
)
//go:embed main.tpl
var mainTemplate string
func genMain(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
return genFile(fileGenConfig{
dir: dir,
subdir: "",
filename: "main.go",
templateName: "mainTemplate",
category: category,
templateFile: mainTemplateFile,
builtinTemplate: mainTemplate,
data: map[string]string{
"importPackages": genMainImports(rootPkg),
//"serviceName": configName,
},
})
}
// genMainImports returns the quoted import paths of the config, handler and
// service-context packages under parentPkg, one per line (tab-indented). The
// last entry carries a trailing newline.
func genMainImports(parentPkg string) string {
	return strings.Join([]string{
		fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, configDir)),
		fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, handlerDir)),
		fmt.Sprintf("\"%s\"\n", pathx.JoinPackages(parentPkg, contextDir)),
	}, "\n\t")
}
genServiceContext
这里删除了一些配置项,显得更简单。
text
package svc
import (
{{.configImport}}
)
// ServiceContext carries the dependencies shared by every handler and logic
// instance. It currently holds only the loaded configuration.
type ServiceContext struct {
	Config {{.config}}
}

// NewServiceContext builds the ServiceContext from the loaded configuration.
func NewServiceContext(c {{.config}}) *ServiceContext {
	return &ServiceContext{
		Config: c,
	}
}
go
package gogen
import (
_ "embed"
"ginctl/api/spec"
"ginctl/config"
"ginctl/util/format"
"ginctl/util/pathx"
)
const contextFilename = "service_context"
//go:embed svc.tpl
var contextTemplate string
func genServiceContext(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
filename, err := format.FileNamingFormat(cfg.NamingFormat, contextFilename)
if err != nil {
return err
}
configImport := "\"" + pathx.JoinPackages(rootPkg, configDir) + "\""
return genFile(fileGenConfig{
dir: dir,
subdir: contextDir,
filename: filename + ".go",
templateName: "contextTemplate",
category: category,
templateFile: contextTemplateFile,
builtinTemplate: contextTemplate,
data: map[string]string{
"configImport": configImport,
"config": "config.Config",
},
})
}
genTypes
这里修改很少。
text
// Code generated by ginctl. DO NOT EDIT.
package types{{if .containsTime}}
import (
"time"
){{end}}
{{.types}}
go
package gogen
import (
_ "embed"
"fmt"
"io"
"os"
"path"
"strings"
"ginctl/api/spec"
apiutil "ginctl/api/util"
"ginctl/config"
"ginctl/util"
"ginctl/util/format"
)
const typesFile = "types"
//go:embed types.tpl
var typesTemplate string
// BuildTypes gen types to string
func BuildTypes(types []spec.Type) (string, error) {
var builder strings.Builder
first := true
for _, tp := range types {
if first {
first = false
} else {
builder.WriteString("\n\n")
}
if err := writeType(&builder, tp); err != nil {
return "", apiutil.WrapErr(err, "Type "+tp.Name()+" generate error")
}
}
return builder.String(), nil
}
// genTypes regenerates the types file from the API's type definitions. The old
// file is removed first so it is always rewritten (overwrite generation).
// "containsTime" is always false, so the template never imports "time".
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
	val, err := BuildTypes(api.Types)
	if err != nil {
		return err
	}
	base, err := format.FileNamingFormat(cfg.NamingFormat, typesFile)
	if err != nil {
		return err
	}
	typeFilename := base + ".go"
	// Best-effort removal: the file may not exist yet.
	_ = os.Remove(path.Join(dir, typesDir, typeFilename))
	return genFile(fileGenConfig{
		dir:             dir,
		subdir:          typesDir,
		filename:        typeFilename,
		templateName:    "typesTemplate",
		category:        category,
		templateFile:    typesTemplateFile,
		builtinTemplate: typesTemplate,
		data: map[string]any{
			"types":        val,
			"containsTime": false,
		},
	})
}
// writeType writes tp as a Go struct declaration ("type X struct { ... }"),
// without a trailing newline. The struct name is title-cased with util.Title.
// Only spec.DefineStruct is supported; any other type yields an error.
func writeType(writer io.Writer, tp spec.Type) error {
	structType, ok := tp.(spec.DefineStruct)
	if !ok {
		return fmt.Errorf("unsupported struct type: %s", tp.Name())
	}
	if _, err := fmt.Fprintf(writer, "type %s struct {\n", util.Title(tp.Name())); err != nil {
		return err
	}
	if err := writeMember(writer, structType.Members); err != nil {
		return err
	}
	_, err := fmt.Fprint(writer, "}")
	return err
}
// writeMember writes the fields of a struct, one per line. An inline
// (embedded) member is written as its bare type name. It is title-cased with
// util.Title, the same function writeType uses for the struct's own name, so
// the embed refers to the identifier that is actually generated
// (strings.Title would turn "base_info" into "Base_Info" instead of "Base_info").
// Regular members are delegated to writeProperty with one level of indentation.
func writeMember(writer io.Writer, members []spec.Member) error {
	for _, member := range members {
		if member.IsInline {
			if _, err := fmt.Fprintf(writer, "%s\n", util.Title(member.Type.Name())); err != nil {
				return err
			}
			continue
		}
		if err := writeProperty(writer, member.Name, member.Tag, member.GetComment(), member.Type, 1); err != nil {
			return err
		}
	}
	return nil
}
genRoutes
gin组织路由的方式与go-zero有很大区别,所以这里修改比较多。
解析的思路是一样的,遍历所有的group,逐条处理。
生成的代码则要改为gin的。
代码如下:
go
package gogen
import (
"fmt"
"ginctl/api/spec"
"ginctl/config"
"ginctl/util/format"
"ginctl/util/pathx"
"github.com/zeromicro/go-zero/core/collection"
"os"
"path"
"sort"
"strconv"
"strings"
"text/template"
)
const (
	// jwtTransKey is the group annotation holding a JWT transition setting.
	// It is parsed into group.jwtTrans but not yet emitted into gin routes.
	jwtTransKey = "jwtTransition"
	// routesFilename is the base name of the generated routes file.
	routesFilename = "routes"
	// routesTemplate renders the handler package's routes file. genRoutes only
	// supplies importPackages and routesAdditions, so the header no longer
	// references an undefined {{.version}} (which rendered as "<no value>").
	routesTemplate = `// Code generated by ginctl. DO NOT EDIT.

package handler

import (
	"net/http"{{if .hasTimeout}}
	"time"{{end}}

	"github.com/gin-gonic/gin"
	{{.importPackages}}
)

func RegisterHandlers(r *gin.Engine, serverCtx *svc.ServiceContext) {
	{{.routesAdditions}}
}
`
	// routesAdditionTemplate renders one route group: a gin RouterGroup named
	// g<index> with the group's prefix, followed by its Use/Handle lines.
	routesAdditionTemplate = `
	g{{.index}} := r.Group("{{.prefix}}")
	{{.routes}}
`
)
// mapping translates a lower-case HTTP method name from the .api spec into the
// net/http constant that is emitted as the first argument of the generated
// gin Handle call.
var mapping = map[string]string{
	"delete":  "http.MethodDelete",
	"get":     "http.MethodGet",
	"head":    "http.MethodHead",
	"post":    "http.MethodPost",
	"put":     "http.MethodPut",
	"patch":   "http.MethodPatch",
	"connect": "http.MethodConnect",
	"options": "http.MethodOptions",
	"trace":   "http.MethodTrace",
}
type (
	// group is one service group of the .api file as consumed by genRoutes.
	// genRoutes only emits routes, middlewares and prefix; the other fields
	// are filled from annotations by getRoutes but are currently unused.
	group struct {
		routes           []route
		jwtEnabled       bool
		signatureEnabled bool
		authName         string
		timeout          string
		middlewares      []string
		prefix           string // route prefix, normalized to start with "/"
		jwtTrans         string
		maxBytes         string
	}

	// route is a single endpoint inside a group.
	route struct {
		method  string // Go expression for the HTTP method (a value from mapping)
		path    string // path registered on the group (relative to its prefix)
		handler string // handler call expression, e.g. "user.GetUserHandler(serverCtx)"
		doc     string // not populated by getRoutes
	}
)
// genRoutes renders the handler package's routes file. Each API group becomes
// a gin RouterGroup (g0, g1, ...) created with the group's prefix. The group's
// middlewares are attached with Use and each route is registered with Handle.
// The previous routes file is removed first so it is always regenerated.
func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
	groups, err := getRoutes(api)
	if err != nil {
		return err
	}
	templateText, err := pathx.LoadTemplate(category, routesAdditionTemplateFile, routesAdditionTemplate)
	if err != nil {
		return err
	}
	// The template text may come from a user-supplied file, so report a parse
	// error instead of panicking.
	gt, err := template.New("groupTemplate").Parse(templateText)
	if err != nil {
		return err
	}

	var builder strings.Builder
	var hasMiddleware bool
	for i, g := range groups {
		var gbuilder strings.Builder
		// Middlewares of the group.
		for _, m := range g.middlewares {
			m = strings.TrimSpace(m)
			if m == "" {
				continue
			}
			hasMiddleware = true
			// Apply the same naming rule as genMiddleware (strip an optional
			// "Middleware" suffix, append it again, title-case the result) so the
			// constructor referenced here is the one generated in the middleware package.
			name := strings.Title(strings.TrimSuffix(m, "Middleware") + "Middleware")
			fmt.Fprintf(&gbuilder, "g%d.Use(middleware.New%s().Handle(serverCtx))\n", i, name)
		}
		// Routes of the group; %q emits the path as a valid Go string literal.
		for _, r := range g.routes {
			fmt.Fprintf(&gbuilder, "g%d.Handle(%s, %q, %s)\n", i, r.method, r.path, r.handler)
		}
		if err := gt.Execute(&builder, map[string]string{
			"index":  strconv.Itoa(i),
			"routes": strings.TrimSpace(gbuilder.String()),
			"prefix": g.prefix,
		}); err != nil {
			return err
		}
	}

	routeFilename, err := format.FileNamingFormat(cfg.NamingFormat, routesFilename)
	if err != nil {
		return err
	}
	routeFilename += ".go"
	// Overwrite generation: remove the old file first. The error is ignored on
	// purpose because the file may not exist yet.
	_ = os.Remove(path.Join(dir, handlerDir, routeFilename))
	return genFile(fileGenConfig{
		dir:             dir,
		subdir:          handlerDir,
		filename:        routeFilename,
		templateName:    "routesTemplate",
		category:        category,
		templateFile:    routesTemplateFile,
		builtinTemplate: routesTemplate,
		data: map[string]any{
			"importPackages":  genRouteImports(rootPkg, hasMiddleware, api),
			"routesAdditions": strings.TrimSpace(builder.String()),
		},
	})
}
// genRouteImports builds the import list of the routes file. It contains the
// svc package and one aliased import per handler group folder (de-duplicated
// and sorted). When hasMiddleware is true, the middleware package follows in a
// separate section.
func genRouteImports(parentPkg string, hasMiddleware bool, api *spec.ApiSpec) string {
	importSet := collection.NewSet()
	importSet.AddStr(fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, contextDir)))
	for _, group := range api.Service.Groups {
		for _, route := range group.Routes {
			// A route-level group annotation wins over the group-level one.
			folder := route.GetAnnotation(groupProperty)
			if len(folder) == 0 {
				folder = group.GetAnnotation(groupProperty)
			}
			// Strip one leading and one trailing "/" exactly like
			// getHandlerFolderPath, so the import path names the directory the
			// handlers were written to (e.g. "/user" must not produce a doubled slash).
			folder = strings.TrimSuffix(strings.TrimPrefix(folder, "/"), "/")
			if len(folder) == 0 {
				continue
			}
			importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder),
				pathx.JoinPackages(parentPkg, handlerDir, folder)))
		}
	}
	imports := importSet.KeysStr()
	sort.Strings(imports)
	projectSection := strings.Join(imports, "\n\t")
	var depSection string
	if hasMiddleware {
		depSection = fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, middlewareDir))
	}
	return fmt.Sprintf("%s\n\n\t%s", projectSection, depSection)
}
// getRoutes converts the API's service groups into the group/route values
// used by genRoutes. A handler living in a group folder is referenced through
// that folder's package alias (see toPrefix) and is title-cased, because it is
// exported from its sub-package.
func getRoutes(api *spec.ApiSpec) ([]group, error) {
	var routes []group
	for _, g := range api.Service.Groups {
		var groupedRoutes group
		for _, r := range g.Routes {
			handler := getHandlerName(r) + "(serverCtx)"
			// A route-level group annotation wins over the group-level one.
			folder := r.GetAnnotation(groupProperty)
			if len(folder) == 0 {
				folder = g.GetAnnotation(groupProperty)
			}
			if len(folder) > 0 {
				handler = toPrefix(folder) + "." + strings.ToUpper(handler[:1]) + handler[1:]
			}
			groupedRoutes.routes = append(groupedRoutes.routes, route{
				method:  httpMethodExpr(r.Method),
				path:    r.Path,
				handler: handler,
			})
		}

		groupedRoutes.timeout = g.GetAnnotation("timeout")
		groupedRoutes.maxBytes = g.GetAnnotation("maxBytes")
		if jwt := g.GetAnnotation("jwt"); len(jwt) > 0 {
			groupedRoutes.authName = jwt
			groupedRoutes.jwtEnabled = true
		}
		if jwtTrans := g.GetAnnotation(jwtTransKey); len(jwtTrans) > 0 {
			groupedRoutes.jwtTrans = jwtTrans
		}
		if g.GetAnnotation("signature") == "true" {
			groupedRoutes.signatureEnabled = true
		}
		// "middleware: A, B": trim blanks and drop empty entries (e.g. from a
		// trailing comma) so they never become invalid constructor names.
		for _, m := range strings.Split(g.GetAnnotation("middleware"), ",") {
			if m = strings.TrimSpace(m); m != "" {
				groupedRoutes.middlewares = append(groupedRoutes.middlewares, m)
			}
		}

		prefix := strings.TrimSpace(strings.ReplaceAll(g.GetAnnotation(spec.RoutePrefixKey), `"`, ""))
		if len(prefix) > 0 {
			groupedRoutes.prefix = path.Join("/", prefix)
		}
		routes = append(routes, groupedRoutes)
	}
	return routes, nil
}

// httpMethodExpr returns the Go expression emitted for an HTTP method in the
// generated Handle call. A known method (matched case-insensitively) becomes
// its net/http constant from mapping. Any other method becomes a quoted
// upper-case string literal, so the generated code stays compilable instead of
// receiving an empty argument.
func httpMethodExpr(method string) string {
	if expr, ok := mapping[strings.ToLower(method)]; ok {
		return expr
	}
	return strconv.Quote(strings.ToUpper(method))
}
// toPrefix turns a group folder such as "user/admin" into the package alias
// used for its handlers ("useradmin") by dropping every "/".
func toPrefix(folder string) string {
	return strings.Join(strings.Split(folder, "/"), "")
}
genResponse
这是新增的自定义文件,用于统一处理response。
text
// Code generated by ginctl. DO NOT EDIT.
package response
import (
"fmt"
"net/http"
"github.com/pkg/errors"
"github.com/gin-gonic/gin"
)
// Error codes fall into four groups:
//   - success (ErrOk)
//   - unspecified errors: plain errors that carry no code (ErrUnspecified)
//   - specified errors: a GyError carrying one of the codes below
//   - codes that are not (yet) used
const (
	ErrOk = iota + 0
)

const (
	ErrUnspecified = -1 // unspecified error: a plain, unclassified error
)

const (
	ErrInvalidToken = iota + 100
	ErrInvalidParam
	ErrNoPermission
	ErrNeedVip
	ErrNeedQuota
)

const (
	ErrInternalError = iota + 500
)

// GyError is a "specified" error: it carries an explicit error code together
// with the message returned to the client.
type GyError struct {
	Code int64  `json:"code"`
	Msg  string `json:"msg"`
}

// Error implements the error interface by returning the message.
func (e GyError) Error() string {
	return e.Msg
}

// NewGyError builds a GyError with the default message for code.
func NewGyError(code int64) GyError {
	return GyError{
		Code: code,
		Msg:  defaultMessage(code),
	}
}

// NewGyErrorReplaceMsg builds a GyError whose message replaces the default one.
func NewGyErrorReplaceMsg(code int64, message string) GyError {
	return GyError{
		Code: code,
		Msg:  message,
	}
}

// NewGyErrorAppendMsg builds a GyError whose message is the default message
// followed by " >> " and message.
func NewGyErrorAppendMsg(code int64, message string) GyError {
	return GyError{
		Code: code,
		Msg:  fmt.Sprintf("%s >> %s", defaultMessage(code), message),
	}
}

// defaultMessage returns the built-in description of code. Every code declared
// above has its own description. Any other code maps to "未使用" ("unused").
func defaultMessage(code int64) string {
	switch code {
	case ErrOk:
		return "success"
	case ErrUnspecified:
		return "[error]"
	case ErrInvalidToken:
		return "token不合法"
	case ErrInvalidParam:
		return "参数不合法"
	case ErrNoPermission:
		return "无权限"
	case ErrNeedVip:
		return "需要VIP权限"
	case ErrNeedQuota:
		return "额度不足"
	case ErrInternalError:
		return "内部错误"
	default:
		return "未使用"
	}
}
// Body is the JSON envelope written by Response: a code, its message and the
// handler's result (null when there is none).
type Body struct {
	Code   int64  `json:"code"`
	Msg    string `json:"msg"`
	Result any    `json:"result"`
}
// Response writes the unified JSON envelope with HTTP status 200.
// Pass a nil err on success. A non-nil err is normalized by formatError:
// a plain error is reported as ErrUnspecified, while a GyError keeps its own
// code and message.
func Response(c *gin.Context, resp any, err error) {
	e := formatError(err)
	c.JSON(http.StatusOK, Body{Code: e.Code, Msg: e.Msg, Result: resp})
}
// formatError normalizes err into a GyError:
//  1. nil err: ErrOk with the "success" message;
//  2. an err that wraps a GyError (found via errors.As): that GyError;
//  3. any other err: ErrUnspecified with err's text appended.
func formatError(err error) (gyError GyError) {
	if err == nil {
		return NewGyError(ErrOk)
	}
	var target GyError
	if errors.As(err, &target) {
		return target
	}
	return NewGyErrorAppendMsg(ErrUnspecified, err.Error())
}
// ResponseInvalid rejects a malformed request. It writes the normalized error
// message as plain text with the given HTTP status code.
func ResponseInvalid(c *gin.Context, httpCode int, err error) {
	// c.String treats its second argument as a printf format. Pass the message
	// as an argument so '%' characters in it (e.g. echoed user input in a
	// binding error) are not interpreted as verbs.
	c.String(httpCode, "%s", formatError(err).Error())
}
go
package gogen
import (
_ "embed"
"os"
"path"
"ginctl/api/spec"
"ginctl/config"
)
//go:embed response.tpl
var responseTemplate string
func genResponse(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
responseFilename := "response" + ".go"
filename := path.Join(dir, responseDir, responseFilename)
os.Remove(filename)
return genFile(fileGenConfig{
dir: dir,
subdir: responseDir,
filename: "response.go",
templateName: "responseTemplate",
category: category,
templateFile: responseTemplateFile,
builtinTemplate: responseTemplate,
data: map[string]string{},
})
}
genHandlers
主要是修改tpl中的代码,要改为gin的。
解析与生成代码的逻辑不需要太多修改。
text
package {{.PkgName}}
import (
{{if .HasRequest}}"net/http"{{end}}
"github.com/gin-gonic/gin"
{{.ImportPackages}}
)
{{.Comment}}
func {{.HandlerName}}(svcCtx *svc.ServiceContext) gin.HandlerFunc {
	return func(c *gin.Context) {
		{{if .HasRequest}}var req types.{{.RequestType}}
		// Bind the incoming request into req; reject it with 400 on failure.
		if err := c.ShouldBind(&req); err != nil {
			response.ResponseInvalid(c, http.StatusBadRequest, err)
			return
		}
		{{end}}l := {{.LogicName}}.New{{.LogicType}}(c, svcCtx)
		// Run the business logic and write the unified JSON envelope.
		{{if .HasResp}}resp, {{end}}err := l.{{.Call}}({{if .HasRequest}}&req{{end}})
		{{if .HasResp}}response.Response(c, resp, err){{else}}response.Response(c, nil, err){{end}}
	}
}
go
package gogen
import (
_ "embed"
"fmt"
"path"
"strings"
"ginctl/api/spec"
"ginctl/config"
"ginctl/util"
"ginctl/util/format"
"ginctl/util/pathx"
)
const defaultLogicPackage = "logic"
//go:embed handler.tpl
var handlerTemplate string
func genHandler(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
handler := getHandlerName(route)
handlerPath := getHandlerFolderPath(group, route)
pkgName := handlerPath[strings.LastIndex(handlerPath, "/")+1:]
logicName := defaultLogicPackage
if handlerPath != handlerDir {
handler = strings.Title(handler)
logicName = pkgName
}
filename, err := format.FileNamingFormat(cfg.NamingFormat, handler)
if err != nil {
return err
}
comment := parseComment(route)
return genFile(fileGenConfig{
dir: dir,
subdir: getHandlerFolderPath(group, route),
filename: filename + ".go",
templateName: "handlerTemplate",
category: category,
templateFile: handlerTemplateFile,
builtinTemplate: handlerTemplate,
data: map[string]any{
"PkgName": pkgName,
"ImportPackages": genHandlerImports(group, route, rootPkg),
"HandlerName": handler,
"RequestType": util.Title(route.RequestTypeName()),
"LogicName": logicName,
"LogicType": strings.Title(getLogicName(route)),
"Call": strings.Title(strings.TrimSuffix(handler, "Handler")),
"HasResp": len(route.ResponseTypeName()) > 0,
"HasRequest": len(route.RequestTypeName()) > 0,
"HasDoc": len(route.JoinedDoc()) > 0,
"Doc": getDoc(route.JoinedDoc()),
"Comment": comment,
},
})
}
// genHandlers generates one handler file per route of every group, stopping
// at the first error.
func genHandlers(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
	for _, g := range api.Service.Groups {
		for _, r := range g.Routes {
			err := genHandler(dir, rootPkg, cfg, g, r)
			if err != nil {
				return err
			}
		}
	}
	return nil
}
// genHandlerImports returns the import block of a handler file. It always
// includes the route's logic package, the svc package and the response
// package. The types package is added (with a trailing newline) only when the
// route has a request type.
func genHandlerImports(group spec.Group, route spec.Route, parentPkg string) string {
	quoted := func(rel string) string {
		return fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, rel))
	}
	imports := []string{
		quoted(getLogicFolderPath(group, route)),
		quoted(contextDir),
		quoted(responseDir),
	}
	if len(route.RequestTypeName()) > 0 {
		imports = append(imports, fmt.Sprintf("\"%s\"\n", pathx.JoinPackages(parentPkg, typesDir)))
	}
	return strings.Join(imports, "\n\t")
}
// getHandlerBaseName returns the route's @handler name with surrounding
// whitespace removed, then a "handler" suffix and then a "Handler" suffix
// stripped. The error result is always nil.
func getHandlerBaseName(route spec.Route) (string, error) {
	name := strings.TrimSpace(route.Handler)
	for _, suffix := range []string{"handler", "Handler"} {
		name = strings.TrimSuffix(name, suffix)
	}
	return name, nil
}
// getHandlerFolderPath returns the directory of a route's handler. A route's
// group annotation wins over its group's annotation. Without either, the
// directory is handlerDir. Otherwise it is handlerDir joined with the folder,
// after one leading and one trailing "/" are stripped.
func getHandlerFolderPath(group spec.Group, route spec.Route) string {
	folder := route.GetAnnotation(groupProperty)
	if folder == "" {
		folder = group.GetAnnotation(groupProperty)
	}
	if folder == "" {
		return handlerDir
	}
	return path.Join(handlerDir, strings.TrimSuffix(strings.TrimPrefix(folder, "/"), "/"))
}
// getHandlerName returns the route's base handler name with "Handler"
// appended. It panics if the base name cannot be determined.
func getHandlerName(route spec.Route) string {
	base, err := getHandlerBaseName(route)
	if err == nil {
		return base + "Handler"
	}
	panic(err)
}
// getLogicName returns the route's base handler name with "Logic" appended.
// It panics if the base name cannot be determined.
func getLogicName(route spec.Route) string {
	base, err := getHandlerBaseName(route)
	if err == nil {
		return base + "Logic"
	}
	panic(err)
}
// parseComment builds the doc comment placed right above a generated handler
// function. If the route has an @doc text, it is unquoted and each of its lines
// is emitted as a "//" comment. The bare text used to be emitted on its own
// line before `func`, which made the generated file invalid Go. Otherwise each
// HandlerDoc entry is emitted as-is, preceded by a newline (presumably these
// are already "//" lines from the .api file — TODO confirm). Returns "" when
// the route has no documentation.
func parseComment(r spec.Route) string {
	if r.AtDoc.Text != "" {
		text := strings.Trim(r.AtDoc.Text, "\"")
		if text == "" {
			return ""
		}
		lines := strings.Split(text, "\n")
		for i, line := range lines {
			if line = strings.TrimSpace(line); line == "" {
				lines[i] = "//"
			} else {
				lines[i] = "// " + line
			}
		}
		return strings.Join(lines, "\n")
	}
	var sb strings.Builder
	for _, d := range r.HandlerDoc {
		fmt.Fprintf(&sb, "\n%s", d)
	}
	return sb.String()
}
genLogic
与handler类似,要改一下tpl。
text
package {{.pkgName}}
import (
"log/slog"
"github.com/gin-gonic/gin"
{{.imports}}
)
type {{.logic}} struct {
	// c is the gin context of the request being handled.
	c *gin.Context
	// svcCtx holds the dependencies shared by all handlers.
	svcCtx *svc.ServiceContext
	// logger is the default slog logger tagged with m=<logic name>.
	logger *slog.Logger
}

// The constructor binds a logic instance to one request's gin context.
func New{{.logic}}(c *gin.Context, svcCtx *svc.ServiceContext) *{{.logic}} {
	return &{{.logic}}{
		c:      c,
		svcCtx: svcCtx,
		logger: slog.With("m", "{{.logic}}"),
	}
}

func (l *{{.logic}}) {{.function}}({{.request}}) {{.responseType}} {
	// todo: add your logic here and delete this line
	{{.returnString}}
}
go
package gogen
import (
_ "embed"
"fmt"
"path"
"strconv"
"strings"
"ginctl/api/parser/g4/gen/api"
"ginctl/api/spec"
"ginctl/config"
"ginctl/util/format"
"ginctl/util/pathx"
)
//go:embed logic.tpl
var logicTemplate string
func genLogic(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
for _, g := range api.Service.Groups {
for _, r := range g.Routes {
err := genLogicByRoute(dir, rootPkg, cfg, g, r)
if err != nil {
return err
}
}
}
return nil
}
// genLogicByRoute renders the logic file of one route. A route with a response
// type gets a method returning (resp, err) that ends in a naked return.
// Otherwise the method returns only an error and ends in "return nil". A
// request type, when present, becomes a pointer parameter named req.
func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
	logic := getLogicName(route)
	goFile, err := format.FileNamingFormat(cfg.NamingFormat, logic)
	if err != nil {
		return err
	}
	imports := genLogicImports(route, rootPkg)

	responseString, returnString := "error", "return nil"
	if len(route.ResponseTypeName()) > 0 {
		responseString = "(resp " + responseGoTypeName(route, typesPacket) + ", err error)"
		returnString = "return"
	}
	var requestString string
	if len(route.RequestTypeName()) > 0 {
		requestString = "req *" + requestGoTypeName(route, typesPacket)
	}

	subDir := getLogicFolderPath(group, route)
	data := map[string]any{
		"pkgName":      subDir[strings.LastIndex(subDir, "/")+1:],
		"imports":      imports,
		"logic":        strings.Title(logic),
		"function":     strings.Title(strings.TrimSuffix(logic, "Logic")),
		"responseType": responseString,
		"returnString": returnString,
		"request":      requestString,
		"hasDoc":       len(route.JoinedDoc()) > 0,
		"doc":          getDoc(route.JoinedDoc()),
	}
	return genFile(fileGenConfig{
		dir:             dir,
		subdir:          subDir,
		filename:        goFile + ".go",
		templateName:    "logicTemplate",
		category:        category,
		templateFile:    logicTemplateFile,
		builtinTemplate: logicTemplate,
		data:            data,
	})
}
// getLogicFolderPath returns the directory of a route's logic file. A route's
// group annotation wins over its group's annotation. Without either, the
// directory is logicDir. Otherwise it is logicDir joined with the folder, after
// one leading and one trailing "/" are stripped.
func getLogicFolderPath(group spec.Group, route spec.Route) string {
	folder := route.GetAnnotation(groupProperty)
	if folder == "" {
		folder = group.GetAnnotation(groupProperty)
	}
	if folder == "" {
		return logicDir
	}
	return path.Join(logicDir, strings.TrimSuffix(strings.TrimPrefix(folder, "/"), "/"))
}
// genLogicImports returns the import block of a logic file. It always contains
// the svc package. The types package is added (with a trailing newline) when
// shallImportTypesPackage reports that the route needs it.
func genLogicImports(route spec.Route, parentPkg string) string {
	imports := []string{fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, contextDir))}
	if shallImportTypesPackage(route) {
		imports = append(imports, fmt.Sprintf("\"%s\"\n", pathx.JoinPackages(parentPkg, typesDir)))
	}
	return strings.Join(imports, "\n\t")
}
// onlyPrimitiveTypes reports whether a type expression such as "[]int",
// "map[string]int64" or "[5]string" is built solely from basic types. The
// expression is split on '[', ']' and ' '. The "map" keyword and numeric
// array lengths are skipped, and every other field must satisfy api.IsBasicType.
func onlyPrimitiveTypes(val string) bool {
	isSeparator := func(r rune) bool { return r == '[' || r == ']' || r == ' ' }
	for _, field := range strings.FieldsFunc(val, isSeparator) {
		_, numErr := strconv.Atoi(field)
		switch {
		case field == "map", numErr == nil:
			// Not a type name: the map keyword or an array dimension.
		case !api.IsBasicType(field):
			return false
		}
	}
	return true
}
// shallImportTypesPackage reports whether a route's logic file must import the
// types package. That is the case when the route has a request type, or when
// it has a response type that is not made solely of basic types.
func shallImportTypesPackage(route spec.Route) bool {
	if len(route.RequestTypeName()) > 0 {
		return true
	}
	resp := route.ResponseTypeName()
	return len(resp) > 0 && !onlyPrimitiveTypes(resp)
}
genMiddleware
这里只是稍作修改。
text
package middleware
import (
"github.com/gin-gonic/gin"
{{.importPackages}}
)
// Generated middleware type; add fields here if the middleware needs state.
type {{.name}} struct {
}

// The constructor returns a new, empty instance of the middleware.
func New{{.name}}() *{{.name}} {
	return &{{.name}}{}
}

// Handle returns the gin handler of this middleware; svcCtx gives it access
// to the shared service context.
func (m *{{.name}})Handle(svcCtx *svc.ServiceContext) gin.HandlerFunc {
	return func(c *gin.Context) {
		// TODO generate middleware implement function, delete after code implementation

		// Passthrough to next handler if need
		c.Next()
	}
}
go
package gogen
import (
_ "embed"
"fmt"
"ginctl/util/pathx"
"strings"
"ginctl/api/spec"
"ginctl/config"
"ginctl/util/format"
)
//go:embed middleware.tpl
var middlewareImplementCode string
func genMiddleware(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
middlewares := getMiddleware(api)
for _, item := range middlewares {
middlewareFilename := strings.TrimSuffix(strings.ToLower(item), "middleware") + "_middleware"
filename, err := format.FileNamingFormat(cfg.NamingFormat, middlewareFilename)
if err != nil {
return err
}
name := strings.TrimSuffix(item, "Middleware") + "Middleware"
err = genFile(fileGenConfig{
dir: dir,
subdir: middlewareDir,
filename: filename + ".go",
templateName: "contextTemplate",
category: category,
templateFile: middlewareImplementCodeFile,
builtinTemplate: middlewareImplementCode,
data: map[string]string{
"name": strings.Title(name),
"importPackages": genMiddlewareImports(rootPkg, api),
},
})
if err != nil {
return err
}
}
return nil
}
// genMiddlewareImports returns the quoted import path of the svc package that
// every generated middleware file needs. api is currently unused.
func genMiddlewareImports(parentPkg string, api *spec.ApiSpec) string {
	svcPkg := pathx.JoinPackages(parentPkg, contextDir)
	return fmt.Sprintf("\"%s\"", svcPkg)
}