引言

前篇文章 中我们搞定了最基础的三件事: 启动, 路由, 可观测性. 正因为基础, 所以这部分代码的变化频率也是最低的.

这次我们就集中精力来处理变化稍快的业务层. 按照我们之前的规则, 业务层其实也可以分成两部分: API 和 具体实现. 其中 API 较为稳定, 很少出现破环性的变更, 一般维护良好的 API 会充分考虑其兼容性. 相较而言, 具体实现的变化速度就快得多了.

由于这篇文章的主要目的还是带着大家一起写代码, 所以如何得到一个设计良好的 API 就不是本篇的重点了, 想深入了解的话推荐读读看 <软件设计哲学> 这本书. 这里我就直接采用 Google API 设计指南 中的接口方案了, 可以稍微浏览下, 有个大概的概念.

至于业务场景, 就假想一个购物车的场景吧, 之后正好可以用来说明事务处理相关的流程. 数据库选择 SQLite, 这样方便在本地把 Demo 跑起来.

另外, 这次我考虑把篇幅稍微控制一下, 上次一口气写太多了, 估计读起来也挺累的 XD.

完整的示例代码依然在: https://github.com/gota33/micro-service

好了, 下面就正式开工吧.

资源

在我们假想的超简单的购物车场景中, 只存在三种资源:

  1. customer 顾客
  2. cart 购物车
  3. item 商品

每种资源都包含五种标准方法:

标准方法 HTTP 映射 HTTP 请求正文 HTTP 响应正文
List GET <collection> 资源列表
Get GET <resource> 资源
Create POST <collection> 资源 资源
Update PUT or PATCH <resource> 资源 资源
Delete DELETE <resource> 不适用

本篇的计划, 就是谈谈如何实现这些标准方法, 对于一般的业务场景而言, 有这五个 API 就已经足够了.

顺带一提, 我设计程序的时候习惯先用较粗的粒度, 自顶向下在脑子里过一遍, 不过具体设计和实现的时候还是按照自底向上的顺序. 因为这样做比较容易控制底层模块的正交性, 避免重复. 另外这样也比较容易做到机制与策略分离, 这是一种非常好的控制复杂性的手段, 感兴趣的话可以到 <UNIX程序设计艺术> 中找找看.

总而言之, 由于商品在依赖关系的最内层, 我们就以它来举例好了.

数据库

商品表包含三个业务字段: title, price, num, 以及两个自生成字段: id, create_time.

字段说明就直接看建表语句吧.

-- internal/cli/init.sql

create table if not exists item 
-- 商品表
(
    id          integer primary key autoincrement,                  -- 主键
    title       text           not null check (length(title) > 0),  -- 标题
    price       decimal(12, 2) not null check (price >= 0),         -- 价格
    num         integer        not null check (num >= 0),           -- 在售数量
    create_time timestamp      not null default current_timestamp   -- 创建时间
);

接下来, 我们希望程序在启动时初始化这个数据库.

第一步, 先加上 SQLite 的配置解析.

// internal/cli/config/sqlite/v1/sqlite.go

package v1

import (
	"context"
	"database/sql"
	"time"

	"github.com/gota33/initializr"
	_ "github.com/mattn/go-sqlite3"
	"github.com/sirupsen/logrus"
)

type Options struct {
	DSN string `json:"dsn"`
}

func New(res initializr.Resource, key string) (db *sql.DB, close func(), err error) {
	var opts Options
	if err = res.Scan(key, &opts); err != nil {
		return
	}
	if db, err = sql.Open("sqlite3", opts.DSN); err != nil {
		return
	}

	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)

	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()

	if err = db.PingContext(ctx); err != nil {
		return
	}

	close = func() {
		if closeErr := db.Close(); closeErr != nil {
			logrus.WithError(err).Warnf("Close SQLite error")
		}
	}
	return
}

整体看上去其实和 MySQL 差不多, 但需要注意的是: SQLite 在以非只读方式打开时只支持一个并发, 所以这里我们就要把线程池的连接数设为 1.

还有对应的配置文件:

// internal/cli/config.json
{
  "app": {
    "name": "demo"
  },
  "sqlite": {
    "dsn": "./demo.db"
  }
}

第二步, 在启动时执行 init.sql

// internal/cli/cli.go
// ...
var (
	//go:embed config.json
	defaultConfig []byte

	//go:embed init.sql
	initSql string
)
// ...
func runServer(c *Context) (err error) {
	// ...
	// 读取配置
	if configUrl := flagConfigUrl.Get(c); configUrl != "" {
		res, err = initializr.FromJsonRemote(configUrl)
	} else {
		res, err = initializr.FromJson(bytes.NewReader(defaultConfig))
	}
	if err != nil {
		return
	}
	// 初始化数据库连接
	if config.RDS, closeRDS, err = initsqlite.New(res, "sqlite"); err != nil {
		return
	}
	// 退出前释放连接
	defer closeRDS()
	// 初始化数据库
	if _, err = config.RDS.ExecContext(c.Context, initSql); err != nil {
		return
	}
	// ...
}

这里我们稍微修改了下配置载入的流程. 如果没有配置远程仓库, 则从本地的 internal/cli/config.json 读取配置. 载入成功后执行 internal/cli/init.sql 初始化数据库, 最后启动服务器.

实体类

数据库往上一层就该定义实体类了.

// internal/service/item/entity.go
package item

import "time"

type Entity struct {
	ID         int64     `json:"id,string"`
	Title      string    `json:"title" validate:"required"`
	Price      float64   `json:"price" validate:"required,min=0"`
	Num        int64     `json:"num" validate:"required,min=0"`
	CreateTime time.Time `json:"createTime"`
}

由于我们采用的是声明式的参数绑定方式, 所以参数解析和校验的规则就都放在 Label 中了.

接口实现

实体类准备好之后就可以开始写数据访问层了, 事情也开始变得有意思了起来.

Get

先从最简单的开始. 实现前记得先看下规范.

// internal/service/item/entity.go
// ...
type dao struct {
	db *sql.DB
}

func (d dao) Get(ctx context.Context, id string) (e Entity, err error) {
	const script = "select id, title, price, num, create_time from item where id = ? limit 1"
	row := d.db.QueryRowContext(ctx, script, id)
	err =  row.Scan(&e.ID, &e.Title, &e.Price, &e.Num, &e.CreateTime)
	return
}

唯一要注意的是: 由于查询用的是 QueryRowContext() 所以记得在对应的 SQL 里加上 limit 1 即使我们知道主键本来最多就只能查出一行, 算是个好习惯吧.

顺便把服务和路由加好, 方便一会儿测试.

首先是服务.

// internal/service/item/service.go
package item

type Service struct {
	dao dao
}

func New(db *sql.DB) Service {
	return Service{dao: dao{db}}
}

type GetRequest struct {
	ItemID string `param:"itemID"`
}

func (srv Service) Get(ctx context.Context, req GetRequest) (res Entity, err error) {
	return srv.dao.Get(ctx, req.ItemID)
}

然后路由, 这里路由我先都列出来, 后面就不重复写了.

// internal/server/router.go
// ...

func (r router) setup() {
	// ...
	r.item()
}

func (r router) item() {
	srv := item.New(r.config.RDS)

	g := r.Group("items")
	g.Post("", handler(srv.Create))
	g.Get("", handler(srv.List))
	g.Get(":itemID", handler(srv.Get))
	g.Patch(":itemID", handler(srv.Update))
	g.Delete(":itemID", handler(srv.Delete))
}

不过先别急着测试, 我们还有一点增强工作要做.

说来也怪, Fiber 的参数绑定, 支持: body, query, reqHeader. 就是没有 URL 路径参数, 所以上面代码中的 param 这个 Label 实际上是我自定义的. 这里就得先补上相应的解析函数了.

// internal/server/server.go
// ...

func paramParser(c *fiber.Ctx, req any) (err error) {
	const tagName = "param"

	keys := make([]string, 0)
	t := reflect.TypeOf(req).Elem()
	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if v := f.Tag.Get(tagName); v != "" {
			keys = append(keys, v)
		}
	}

	if len(keys) == 0 {
		return
	}

	var (
		decoder *mapstructure.Decoder
		config  = &mapstructure.DecoderConfig{
			TagName: tagName,
			Result:  req,
		}
	)
	if decoder, err = mapstructure.NewDecoder(config); err != nil {
		return
	}

	params := make(map[string]any, len(keys))
	for _, key := range keys {
		params[key] = c.Params(key)
	}
	return decoder.Decode(params)
}

这里用到了一个将 map 内容绑定到 struct 的第三方库 github.com/mitchellh/mapstructure.

在 handler 中调用它.

// internal/server/server.go
// ...

func handler[Request any, Response any](h func(context.Context, Request) (Response, error)) fiber.Handler {
	return func(c *fiber.Ctx) (err error) {
		// ...
		if err = c.BodyParser(&req); err != nil && !errors.Is(err, fiber.ErrUnprocessableEntity) {
			return
		}
		if err = c.QueryParser(&req); err != nil {
			return
		}
		if err = paramParser(c, &req); err != nil {
			return
		}
		// ...
	}
}

还差一点了, 另一个增强工作是要支持返回自定义的 http code, 因为 Delete 接口需要返回 204 No Content. 像我们现在这样不管三七二十一就返回 JSON 肯定是不行的.

比较简单的方案是, 如果返回值为 int 类型, 就作为状态码返回, 如果是其他类型就编码成 JSON 返回. 因为 Google API 设计规范中只要有返回值, 就一定是一个 JSON Object, 所以并不会产生冲突.

但是还有一点小问题要处理, go1.18 的中的泛型还不支持类型断言 (官方说在 go1.19 中可能会支持), 所以我们就得用一点小技巧了.

// internal/server/server.go
// ...

func handler[Request any, Response any](h func(context.Context, Request) (Response, error)) fiber.Handler {
	return func(c *fiber.Ctx) (err error) {
		// ...
		// 这里先把泛型类型转为 any
		var temp any = res
		// 然后就可以顺利做类型断言了
		switch v := temp.(type) {
		case int:
			return c.SendStatus(v)
		default:
			return c.JSON(res)
		}
	}
}

最后来看看效果:

$ curl http://localhost:8080/items/1

# {"id":"1","title":"New title 1","price":20,"num":10,"createTime":"2022-04-26T07:10:59Z"}

格式化一下:

{
    "id": "1",
    "title": "New title 1",
    "price": 20,
    "num": 10,
    "createTime": "2022-04-26T07:10:59Z"
}

完全符合预期, 那么就继续吧.

Create

好了, 有意思的地方开始了. 按照 API 规范, 里面要求 Create 返回插入后的资源. 这很好理解, 毕竟有些自生成字段, 例如: id, createTime 之类, 只能由服务端返回.

虽然有些数据库支持读取自生成字段的函数, 但 SQLite 显然不支持获取除主键外的自生成字段. 而且即使支持, 通过函数获取也无法得到因为触发器而发生变更的字段.

因此, 最保险的方式还是通过主键从数据库把插入后的记录再读出来. 这时就需要用到 SQL 事务了, 否则就会遇到脏读的问题.

OK, 背景交代完毕, 我们来看看最直白的代码是怎样的.

// internal/service/item/entity.go
// ...
func (d dao) Create(ctx context.Context, e Entity) (next Entity, err error) {
	const (
		sqlCreate = "insert into item (title, price, num) values (?, ?, ?)"
		sqlGet    = "select id, title, price, num, create_time from item where id = ? limit 1"
	)
	var (
		tx *sql.Tx
		sr sql.Result
		id int64
	)
	if tx, err = d.db.BeginTx(ctx, nil); err != nil {
		return
	}

	defer func() {
		if err != nil {
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
				logrus.WithError(rollbackErr).Warnf("Rollback error")
			} else {
				err = tx.Commit()
			}
		}
	}()

	if sr, err = tx.ExecContext(ctx, sqlCreate, e.Title, e.Price, e.Num); err != nil {
		return
	}
	if id, err = sr.LastInsertId(); err != nil {
		return
	}
	row := tx.QueryRowContext(ctx, sqlGet, strconv.FormatInt(id, 10))
	err = row.Scan(&next.ID, &next.Title, &next.Price, &next.Num, &next.CreateTime)
	return
}

OMG! 我们写了快 30 行代码, 然而里面就只有三五行代码是业务相关的, 剩下的全都是样板代码. 这直接违背了正交性原则, 绝对不能忍.

这里也不得不吐槽下 Go 的 sql 包, 确实还有不少改进空间, 至今不明白为什么就不能多加点接口类型. 不过吐槽归吐槽, 下面还是进入自己动手, 丰衣足食的环节.

首先, 做一些准备工作, 把 sql 包中的一些常用类型抽象成接口.

// internal/service/entity/dao.go
package entity

type Scanner interface {
	Scan(dest ...any) error
}

type SQLCmd interface {
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}

type SQLBeginTx interface {
	SQLCmd
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}

type SQLTx interface {
	SQLCmd
	Commit() error
	Rollback() error
}

稍微说明一下:

  1. Scanner : *sql.Row*sql.Rows
  2. SQLCmd : 基本的 SQL 读写接口
  3. SQLBeginTx : *sql.DB
  4. SQLTx : *sql.Tx

有了这些接口, 我们就可以加一个工具函数来处理 SQL 事务了.

// internal/service/entity/dao.go
// ...

func BeginTx(ctx context.Context, db SQLCmd, opts *sql.TxOptions) (tx SQLCmd, finish func(error) error, err error) {
	switch db := db.(type) {
	case SQLBeginTx:
		tx, err = db.BeginTx(ctx, opts)
		finish = func(cause error) error { return finishTx(tx.(*sql.Tx), cause) }
	default:
		tx = db
		finish = func(cause error) error { return cause }
	}
	return
}

func finishTx(tx SQLTx, cause error) (err error) {
	if err = cause; err == nil {
		return tx.Commit()
	}
	if rollbackErr := tx.Rollback(); rollbackErr != nil {
		logrus.WithError(rollbackErr).Warnf("Rollback error")
	}
	return
}

稍微解释下这里的逻辑, 如果传入的 dbSQLBeginTx 也就是 *sql.DB 那么就开始一个新事务, 并且返回一个可以提交或回滚该事务的 finish() 函数. 如果不是, 那很可能 db 是一个在其他地方开启的事务, 那么就直接 pass, finish() 也不做任何操作.

用上面的工具函数, 可以把老的代码变成这样:

// internal/service/entity/dao.go
// ...
func (d dao) Create(ctx context.Context, e Entity) (next Entity, err error) {
	// ...
	var (
		tx     SQLCmd
		finish func(error) error
	)
	if tx, finish, err = BeginTx(ctx, d.db, nil); err != nil {
		return
	}
	defer func() { err = finish(err) }()
	// ...
}

是不是干净了不少? 不过这才刚刚开始呢.

下面要处理的是查询插入结果的代码, 明明相同的逻辑已经在 Get 中实现了一遍, 却仅仅因为这里用的是 *sql.Tx 而成员变量是 *sql.DB 就得再写一遍, 这也太说不过去了. 所以, 我们把成员变量先给换成接口类型.

// internal/service/entity/dao.go
// ...
type dao struct {
	db SQLCmd
}
// ...

接着用 Get 函数做结果查询的工作.

// internal/service/item/entity.go
// ...

func (d dao) Create(ctx context.Context, e Entity) (next Entity, err error) {
	const script = "insert into item (title, price, num) values (?, ?, ?)"
	var (
		tx     SQLCmd
		finish func(error) error
		sr     sql.Result
		id     int64
	)
	if tx, finish, err = BeginTx(ctx, d.db, nil); err != nil {
		return
	}
	defer func() { err = finish(err) }()

	if sr, err = tx.ExecContext(ctx, script, e.Title, e.Price, e.Num); err != nil {
		return
	}
	if id, err = sr.LastInsertId(); err != nil {
		return
	}
	
	// 注意这里
	sub := dao{db: tx}
	return sub.Get(ctx, strconv.FormatInt(id, 10))
}

由于 BeginTx() 中已经做好了 PassBy 的工作, 所以多层 Tx 嵌套也是没有问题的哦, 最终会被展开成一层.

接着补上对应的服务.

// internal/server/server.go
// ...

type CreateRequest struct {
	Entity
}

func (srv Service) Create(ctx context.Context, req CreateRequest) (res Entity, err error) {
	return srv.dao.Create(ctx, req.Entity)
}

最后来测试一下:

$ curl --location --request POST 'http://localhost:8080/items/' \
--header 'Content-Type: application/json' \
--data-raw '{"title": "item 5", "price": 20, "num": 100}'

# {"id":"6","title":"item 5","price":20,"num":100,"createTime":"2022-04-27T04:19:36Z"}

没问题, Let’s move on!

Delete

下面来个轻松点的接口, API 规范.

// internal/service/item/entity.go
// ...
func (d dao) Delete(ctx context.Context, req DeleteRequest) (err error) {
	const script = "delete from item where id = ?"
	var (
		sr  sql.Result
		num int64
	)
	if sr, err = d.db.ExecContext(ctx, script, req.ItemID); err != nil {
		return
	}
	if num, err = sr.RowsAffected(); err != nil {
		return
	}
	if num == 0 {
		err = errors.WithNotFound(errors.NotFound, errors.ResourceInfo{
			ResourceType: "item",
			ResourceName: "items/" + req.ItemID,
		})
	}
	return
}

这个可以说的不多, 只要记得遵循 API 设计规范, 确保最终效果上的幂等性. 做到只有首次调用成功返回 204, 后续调用返回 404.

别忘了服务的代码.

type DeleteRequest struct {
	ItemID string `param:"itemID"`
}

func (srv Service) Delete(ctx context.Context, req DeleteRequest) (code int, err error) {
	if err = srv.dao.Delete(ctx, req); err == nil {
		code = http.StatusNoContent
	}
	return
}

List

按照 API 规范, 这个接口的入参是一个分页参数, 我们先把它声明出来.

// internal/service/item/entity.go
// ...
type ListRequest struct {
	PageSize  int    `query:"pageSize"`
	PageToken string `query:"pageToken"`
}

由于 List 在 SQL 中取的字段和 Get 完全一致, 所以可以提出来. 并且也可以借助之前声明的 Scanner 接口, 抽一个公共的 scanAllFields() 函数.

// internal/service/item/entity.go
// ...
const allFields = "id, title, price, num, create_time"

type Entity struct {
	// ...
}

func (e *Entity) scanAllFields(row Scanner) error {
	return row.Scan(&e.ID, &e.Title, &e.Price, &e.Num, &e.CreateTime)
}

最后再补一个工具函数 CloseRows().

// internal/service/entity/dao.go
// ...

func CloseRows(rows *sql.Rows) {
	if closeErr := rows.Close(); closeErr != nil {
		logrus.WithError(closeErr).Warn("Close rows error")
	}
}

准备完毕, 可以开写了.

按照规范, List 接受 PageSize 和 PageToken, 返回当前分页和 NextPageToken, 这种实现不需要返回 Total, 也没有用到 Offset, 所以对数据库性能还是非常友好的. 这里我们就用 ID 来当 PageToken.

// internal/service/item/entity.go
// ...
func (d dao) List(ctx context.Context, req ListRequest) (res ListResponse, err error) {
	const script = "select " + allFields + " from item where id > ? limit ?"

	if req.PageSize == 0 {
		req.PageSize = 20
	}

	if req.PageToken == "" {
		req.PageToken = "0"
	}

	var rows *sql.Rows
	if rows, err = d.db.QueryContext(ctx, script, req.PageToken, req.PageSize); err != nil {
		return
	}

	defer CloseRows(rows)

	for rows.Next() {
		var e Entity
		if err = e.scanAllFields(rows); err != nil {
			return
		}
		res.Items = append(res.Items, e)
	}
	if err = rows.Err(); err != nil {
		return
	}

	if size := len(res.Items); size == req.PageSize {
		res.NextPageToken = strconv.FormatInt(res.Items[size-1].ID, 10)
	}
	return
}

代码量稍微有点大, 不过不存在什么重复代码, 整体上还是比较直观的. 里面处理默认值的地方先不用管, 后续的改进中会挪到其他地方去.

测试前补上服务代码:

// internal/server/server.go
// ...

type ListRequest struct {
	PageSize  int    `query:"pageSize"`
	PageToken string `query:"pageToken"`
}

type ListResponse struct {
	NextPageToken string   `json:"nextPageToken"`
	Items         []Entity `json:"items"`
}

func (srv Service) List(ctx context.Context, req ListRequest) (res ListResponse, err error) {
	return srv.dao.List(ctx, req)
}

最后依然测下看看.

$ curl --location --request GET 'http://localhost:8080/items?pageSize=2'

# {"nextPageToken":"2","items":[{"id":"1","title":"New title 1","price":20,"num":10,"createTime":"2022-04-26T07:10:59Z"},{"id":"2","title":"item 2","price":50,"num":20,"createTime":"2022-04-26T07:12:57Z"}]}

格式化之后的结果, 如果 nextPageToken 为空, 就是翻到最后了:

{
    "nextPageToken": "2",
    "items": [
        {
            "id": "1",
            "title": "New title 1",
            "price": 20,
            "num": 10,
            "createTime": "2022-04-26T07:10:59Z"
        },
        {
            "id": "2",
            "title": "item 2",
            "price": 50,
            "num": 20,
            "createTime": "2022-04-26T07:12:57Z"
        }
    ]
}

Update

好了, 终于到最后一个接口了. 先看看 API 规范. 这里我们选择实现其中比较常用的 PATCH 接口. 这个接口中要用到一个 FieldMask 类型, 先把它声明出来.

// internal/service/entity/entity.go
type FieldMask struct {
	Paths []string `json:"paths,omitempty"`
}

它的主要作用是: 过滤出传入的 JSON 中真正用于更新的字段. 这样可以显著减少客户端的工作量, (不用每次拼不同的 JSON 了). 不过, 它的实现可简单可复杂, 由于我们的 item 实体中不包含复杂或嵌套的字段, 这里就写个比较简单的实现了.

// internal/service/entity/entity.go
// ...
func (mask FieldMask) ToMap(e any) (m map[string]any, err error) {
	var (
		data    []byte
		mFields map[string]any
	)
	if data, err = json.Marshal(e); err != nil {
		return
	}
	if err = json.Unmarshal(data, &mFields); err != nil {
		return
	}

	if size := len(mask.Paths); size > 0 {
		m = make(map[string]any, size)
		for _, path := range mask.Paths {
			if value, ok := mFields[path]; ok {
				m[path] = value
			}
		}
	} else {
		delete(mFields, "id")
		m = mFields
	}
	return
}

这里首先将传进来的实体编码成 JSON, 再解析到 map[string]any 类型.

  • 如果传入了 paths, 则根据其内容生成一个子 map, 这里注意要忽略掉实体中不存在的 path
  • 如果没有 paths, 则返回除主键外所有字段

但光有这个还不够, 我们最终是要把待更新的 map[string]any 转换成可执行的 SQL update 语句. 所以再来写一个转换函数.

// internal/service/entity/entity.go
// ...
func SQLUpdate(table string, id any, fields map[string]any) (script string, args []any) {
	var sb strings.Builder
	sb.WriteString("update ")
	sb.WriteString(table)
	sb.WriteString(" set ")
	args = append(args, mapJoin(&sb, fields, ", ")...)
	sb.WriteString(" where id = ?")
	args = append(args, id)
	script = sb.String()
	return
}

func mapJoin(sb io.StringWriter, m map[string]any, sep string) (args []any) {
	var flag bool
	for field, value := range m {
		if flag {
			sb.WriteString(sep)
		} else {
			flag = true
		}
		sb.WriteString(field)
		sb.WriteString(" = ?")
		args = append(args, value)
	}
	return
}

由于标准接口中的逻辑比较固定, 我们就不上第三方 SQLBuilder 库了, 如果需要处理复杂的 JSON 类型, 大家可以自由替换.

最后我们回到 item 包下写 Update 的具体实现.

// internal/service/item/entity.go
// ...

func (d dao) Update(ctx context.Context, req UpdateRequest) (res Entity, err error) {
	var fields map[string]any
	if fields, err = req.UpdateMask.ToMap(req.Item); err != nil {
		return
	}
	script, args := SQLUpdate("item", req.ItemID, fields)

	var (
		tx     SQLCmd
		finish func(error) error
	)
	if tx, finish, err = BeginTx(ctx, d.db, nil); err != nil {
		return
	}
	defer func() { err = finish(err) }()

	if _, err = tx.ExecContext(ctx, script, args...); err != nil {
		return
	}

	sub := dao{db: tx}
	return sub.Get(ctx, req.ItemID)
}

这里的代码和 Create 非常相似, 就不重复解释了. 不过服务代码还是贴一下.

// internal/server/server.go
// ...

type UpdateRequest struct {
	ItemID     string           `param:"itemID"`
	UpdateMask entity.FieldMask `json:"updateMask"`
	Item       Entity           `json:"item"`
}

func (srv Service) Update(ctx context.Context, req UpdateRequest) (res Entity, err error) {
	return srv.dao.Update(ctx, req)
}

还有一点需要注意的是, 由于 UpdateRequest 的校验规则和其他接口不太一样, 它只需要校验 FieldMask 中存在的字段, 所以还需要改一下之前 handler 中的校验逻辑.

// internal/server/server.go
// ...

func handler[Request any, Response any](h func(context.Context, Request) (Response, error)) fiber.Handler {
	return func(c *fiber.Ctx) (err error) {
		// ...
		if err = doValidate(c, req); err != nil {
			return
		}
		// ...
	}
}

func doValidate(c *fiber.Ctx, req any) (err error) {
	// 跳过更新接口
	if c.Method() != http.MethodPatch {
		return entity.Validate.Struct(req)
	}
	return
}

把 Validate 挪个位置

// internal/service/entity/entity.go
// ...

var Validate = validator.New()

然后在数据层实现校验逻辑

// internal/service/item/entity.go
// ...

func (req UpdateRequest) validate() (err error) {
	size := len(req.UpdateMask.Paths)
	if size == 0 {
		return entity.Validate.Struct(req.Item)
	}
	
	// 转换 paths 里的字段为 validator 要求的格式
	names := make([]string, size)
	for i, path := range req.UpdateMask.Paths {
		runes := []rune(path)
		runes[0] = unicode.ToUpper(runes[0])
		names[i] = string(runes)
	}
	return entity.Validate.StructPartial(req.Item, names...)
}

// ...

func (d dao) Update(ctx context.Context, req UpdateRequest) (res Entity, err error) {
	if err = req.validate(); err != nil {
		return
	}
	// ...
}

总结

本想着只写五个标准接口篇幅应该不会很长, 没想到还是快 20k 字, 所以只能就此打住了.

想完整了解这套 API 设计思路, 还是推荐直接读 Google 的 API 设计指南, 因为文章里的代码要控制长度, 所以很多边角的情况并没有处理得特别细.

另外这部分代码其实是可以通过泛型在不同的模块间复用的, 但这部分贴出来的话很多内容和之前是类似的, 所以想了解的话可以直接看代码仓库里的 internal/service/entity/general.go 这个文件.

下一篇计划以自定义接口如何实现为引, 讲讲 SQL 与编程语言相比截然不同的编码策略, 也就是进攻式编码是怎么一回事.

感谢阅读, 欢迎留言!