【AI Shift Advent Calendar 2023】connect-go から学ぶコード生成

こんにちは、AIShift バックエンドエンジニアの石井(@sugar235711)です。
本記事はAIShift Advent Calendar 2023の 4 日目の記事となります。

以前、connect-web について LT を行ったことがあり、今回はその Go 版のプラグインである connect-go についての記事です。

zennにも公開しているので見やすい方でご覧ください。

本記事では、connect-goがどのようにして proto からコードを自動生成しているのかを調査します。

※注意
この記事ではプラグインの内部実装を見るので、Connect の具体的な使い方は紹介しません。

使い方が知りたい場合はチュートリアルをご覧ください。

チュートリアル

はじめに

Connectは HTTP/1.1 または HTTP/2 上で動作し、ストリーミングを含む gRPC と gRPC-Web の互換性があるプロトコルです。

プラグインを通じて、様々な言語やライブラリに適した IF を proto から生成できます。

https://connectrpc.com/

自動生成されるファイル群について

Connect Go のプラグインについて、コードを追いながらどのようにして proto からコードを自動生成しているのかを調査します。

https://github.com/connectrpc/connect-go/tree/main

まず、Connect Go のプラグインによって出力されるファイルはxxx.connect.goという形式で出力されます。

生成されるコード自体は非常にシンプルで、proto で定義した RPC に対応する Client 及び Handler の Interface などが生成されます。

// https://github.com/connectrpc/connect-go/blob/main/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go

type PingServiceClient interface {
    // Ping sends a ping to the server to determine if it's reachable.
    Ping(context.Context, *connect.Request[v1.PingRequest]) (*connect.Response[v1.PingResponse], error)
    // Fail always fails.
    Fail(context.Context, *connect.Request[v1.FailRequest]) (*connect.Response[v1.FailResponse], error)
    // Sum calculates the sum of the numbers sent on the stream.
    Sum(context.Context) *connect.ClientStreamForClient[v1.SumRequest, v1.SumResponse]
    // CountUp returns a stream of the numbers up to the given request.
    CountUp(context.Context, *connect.Request[v1.CountUpRequest]) (*connect.ServerStreamForClient[v1.CountUpResponse], error)
    // CumSum determines the cumulative sum of all the numbers sent on the stream.
    CumSum(context.Context) *connect.BidiStreamForClient[v1.CumSumRequest, v1.CumSumResponse]
}

// ...

type PingServiceHandler interface {
    // Ping sends a ping to the server to determine if it's reachable.
    Ping(context.Context, *connect.Request[v1.PingRequest]) (*connect.Response[v1.PingResponse], error)
    // Fail always fails.
    Fail(context.Context, *connect.Request[v1.FailRequest]) (*connect.Response[v1.FailResponse], error)
    // Sum calculates the sum of the numbers sent on the stream.
    Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*connect.Response[v1.SumResponse], error)
    // CountUp returns a stream of the numbers up to the given request.
    CountUp(context.Context, *connect.Request[v1.CountUpRequest], *connect.ServerStream[v1.CountUpResponse]) error
    // CumSum determines the cumulative sum of all the numbers sent on the stream.
    CumSum(context.Context, *connect.BidiStream[v1.CumSumRequest, v1.CumSumResponse]) error
}

Interface で使用されているリクエスト及びレスポンスの構造体はprotoc-gen-goのプラグインを使用して生成されています。

https://pkg.go.dev/google.golang.org/grpc/cmd/protoc-gen-go-grpc

使用するプラグインや出力先などはprotocまたはbufを使用して変更することが可能です。

  • protoc
$ protoc --go_out=gen --connect-go_out=gen path/to/file.proto
  • buf
version: v1
plugins:
  - name: go
      out: gen
  - name: connect-go
    out: gen
$ buf generate
-->generated
//  gen/path/to/file.pb.go
//  gen/path/to/connectfoov1/file.connect.go

gRPC 互換を謳っている面からも当然ですが、公式で提供されている proto-gen plugin によって生成されたコードと大差ないものが生成されていることがわかります。
https://grpc.io/docs/languages/go/generated-code/

Connect Go によるファイル生成について

本題です。

リポジトリを眺めていると、このプラグインのエントリーポイントであるcmd/protoc-gen-connect-go/main.goが見つかります。

https://github.com/connectrpc/connect-go/blob/main/cmd/protoc-gen-connect-go/main.go

この中のmain関数を見てみます。

// cmd/protoc-gen-connect-go/main.go
package main

import (
 // ....
    connect "connectrpc.com/connect"
    "google.golang.org/protobuf/compiler/protogen"
    "google.golang.org/protobuf/reflect/protoreflect"
    "google.golang.org/protobuf/types/descriptorpb"
    "google.golang.org/protobuf/types/pluginpb"
)

func main() {
    // ...
    protogen.Options{}.Run(
        func(plugin *protogen.Plugin) error {
            plugin.SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
            for _, file := range plugin.Files {
                if file.Generate {
                    generate(plugin, file)
                }
            }
            return nil
        },
    )
}

大まかな流れとしては以下の通りです。

protobufのcompiler/protogenを使用して必要な構造体を生成
↓
connect-goのgenerate関数に構造体を投げて書き込み
↓
出力

import しているライブラリからも Go のファイル自体の生成やロジック部分は protobuf-go の機能をベースとして使用しており、Connect Go 側ではそれをラップしているだけのように見えます。

protogen

protogen パッケージは protoc プラグインを書くためのサポートを提供しています。

実装自体は protocの標準入力から CodeGeneratorRequest を読み込み、標準出力に CodeGeneratorResponse として吐き出す実装になっています。

https://github.com/protocolbuffers/protobuf-go/blob/2087447a6b4abbfd849dd401e284315847c01834/compiler/protogen/protogen.go#L59-L93

protogen.Options{}.Run()の中で呼び出されているgenerate関数を見ていくと、protogen が提供するGeneratedFile構造体の中に付加情報を詰めていそうなことがわかります。

// cmd/protoc-gen-connect-go/main.go
func generate(plugin *protogen.Plugin, file *protogen.File) {
  // ...
    generatedFile := plugin.NewGeneratedFile(
        file.GeneratedFilenamePrefix+generatedFilenameExtension,
        protogen.GoImportPath(path.Join(
            string(file.GoImportPath),
            string(file.GoPackageName),
        )),
    )
    generatedFile.Import(file.GoImportPath)
    generatePreamble(generatedFile, file)
    generateServiceNameConstants(generatedFile, file.Services)
    generateServiceNameVariables(generatedFile, file)
    for _, service := range file.Services {
        generateService(generatedFile, service)
    }
}
GeneratedFile

ファイルを構成する上で必要な情報を詰め込んでいる構造体です。
基本的には fmt.Fprint を通して buf フィールド に書き込んでいくことでファイルを構成しているようです。

// protobuf-go/compiler/protogen/protogen.go
// A GeneratedFile is a generated file.
type GeneratedFile struct {
    gen              *Plugin
    skip             bool
    filename         string
    goImportPath     GoImportPath
    buf              bytes.Buffer
    packageNames     map[GoImportPath]GoPackageName
    usedPackageNames map[GoPackageName]bool
    manualImports    map[GoImportPath]bool
    annotations      map[string][]Annotation
}

// NewGeneratedFile creates a new generated file with the given filename
// and import path.
func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
    g := &GeneratedFile{
        gen:              gen,
        filename:         filename,
        goImportPath:     goImportPath,
        packageNames:     make(map[GoImportPath]GoPackageName),
        usedPackageNames: make(map[GoPackageName]bool),
        manualImports:    make(map[GoImportPath]bool),
        annotations:      make(map[string][]Annotation),
    }

    // All predeclared identifiers in Go are already used.
    for _, s := range types.Universe.Names() {
        g.usedPackageNames[GoPackageName(s)] = true
    }

    gen.genFiles = append(gen.genFiles, g)
    return g
}

それでは実際に書き込み処理を行っている関数を見ていきます。
今回はgenerate関数のgenerateService関数内で呼び出されているgenerateServerInterface関数に着目します。

// cmd/protoc-gen-connect-go/main.go
func generateServerInterface(g *protogen.GeneratedFile, service *protogen.Service, names names) {
    wrapComments(g, names.Server, " is an implementation of the ", service.Desc.FullName(), " service.")
    if isDeprecatedService(service) {
        g.P("//")
        deprecated(g)
    }
    g.AnnotateSymbol(names.Server, protogen.Annotation{Location: service.Location})
    g.P("type ", names.Server, " interface {")
    for _, method := range service.Methods {
        leadingComments(
            g,
            method.Comments.Leading,
            isDeprecatedMethod(method),
        )
        g.AnnotateSymbol(names.Server+"."+method.GoName, protogen.Annotation{Location: method.Location})
        g.P(serverSignature(g, method))
    }
    g.P("}")
    g.P()
}

この関数を眺めてみるとg.P("type ", names.Server, " interface {")の部分から、必要な情報が文字列結合されて Handler 等の Interface が生成されていそうなことがわかります。

身も蓋もない話ですが、自動生成ではあらかじめ用意していた文字列リテラルと外から注入される文字列を結合してそれをファイルとして出力しているだけです。
中身を紐解いていくと特別な処理はしていないことがわかってきますね。

では次にg.Pが何をしているのかを見ていきます。

g.P自体は受け取った文字列をfmt.Fprintに通して string から bytes.Buffer に変換して Buffer に突っ込んでいくレシーバーとして定義されています。

// protobuf-go/compiler/protogen/protogen.go
// P prints a line to the generated output. It converts each parameter to a
// string following the same rules as fmt.Print. It never inserts spaces
// between parameters.
func (g *GeneratedFile) P(v ...interface{}) {
    for _, x := range v {
        switch x := x.(type) {
        case GoIdent:
            fmt.Fprint(&g.buf, g.QualifiedGoIdent(x))
        default:
            fmt.Fprint(&g.buf, x)
        }
    }
    fmt.Fprintln(&g.buf)
}
fmt.Fprint

余談ですがfmt.Fprintは第一引数でio.Writerを受け取る関数なので、io.Writerを実装しているbytes.Bufferfmt.Fprintの第一引数に渡すことができます。

func Fprint(w io.Writer, a ...any) (n int, err error) {
    p := newPrinter()
    p.doPrint(a)
    n, err = w.Write(p.buf)
    p.free()
    return
}

generateServerInterfaceに戻ります。g.Pと同様にコメントを文字列として結合しているwrapComments関数があります。

少し冗長に見えますが、処理としては以下のようになっています。

  1. 最初の for で一度必要な情報を全て Buffer に書き込んで、文字列に再度変換する
  2. func (b *Buffer) Reset()を使用して Buffer を空にする
  3. リセットした Buffer はそのまま使い回し、次の for 内で WriteString を使用して文字列を書き込む
// cmd/protoc-gen-connect-go/main.go
// Raggedy comments in the generated code are driving me insane. This
// word-wrapping function is ruinously inefficient, but it gets the job done.
func wrapComments(g *protogen.GeneratedFile, elems ...any) {
    text := &bytes.Buffer{}
    for _, el := range elems {
        switch el := el.(type) {
        case protogen.GoIdent:
            fmt.Fprint(text, g.QualifiedGoIdent(el))
        default:
            fmt.Fprint(text, el)
        }
    }
    words := strings.Fields(text.String())
    text.Reset()
    var pos int
    for _, word := range words {
        numRunes := utf8.RuneCountInString(word)
        if pos > 0 && pos+numRunes+1 > commentWidth {
            g.P("// ", text.String())
            text.Reset()
            pos = 0
        }
        if pos > 0 {
            text.WriteRune(' ')
            pos++
        }
        text.WriteString(word)
        pos += numRunes
    }
    if text.Len() > 0 {
        g.P("// ", text.String())
    }
}
  1. 最初の for で一度必要な情報を全て Buffer に書き込んで、文字列に再度変換する

これは大量の入力に対してそのまま文字列のまま結合を行うとすると、新しい文字列のアロケートが頻繁に発生してしまうため、メモリ効率が悪いです。
なので、一度全てbytes.Bufferに書き込んで最後に文字列として取り出す実装にしていると思われます。

text := &bytes.Buffer{}
for _, el := range elems {
    switch el := el.(type) {
    case protogen.GoIdent:
        fmt.Fprint(text, g.QualifiedGoIdent(el))
    default:
        fmt.Fprint(text, el)
    }
}
words := strings.Fields(text.String())
文字列結合のパフォーマンス

2023 年 12 月現在アロケーションの回数に変化はありそうですが、
キャパシティ指定付き[]byte がパフォーマンスは良さそうです。
https://qiita.com/ono_matope/items/d5e70d8a9ff2b54d5c37#comment-c163b00eb629db616a7e

  1. func (b *Buffer) Reset()を使用して Buffer を空にする

しれっと書かれているtext.Reset()ですが、結構重要で、bytes.Bufferは内部で[]byteを使用しており、Resetを呼び出すと[]byteの中身を空にすることができます。
その際に[]byteの容量(cap)は変更されないため、後続の処理で再度WriteStringを呼び出すと、再度[]byteのアロケートを行わずに済みます。

func (b *Buffer) Reset() {
    b.buf = b.buf[:0]
    b.off = 0
    b.lastRead = opInvalid
}
details 容量(cap)は変更されないとは

Go のスライスは内部的には実体の配列を指すポインタと長さ(len)と容量(cap)を持っています。

type slice struct {
    array unsafe.Pointer
    len   int
    cap   int
}

すごく雑にいうと、容量(cap)のみを指定してスライスを作成すると、ゼロ値も何も含まれない領域が確保され、容量を超えない書き込みは再度配列の確保を行わずに済みます。
逆に言えば、容量を超える書き込みを行うと再度配列の確保を行う必要があります。
なので、wrapComments 内では既に確保されている配列を再利用するために、Reset を呼び出しています。

詳しくは下記を参照してください。
https://tenntenn.dev/ja/posts/qiita-5229bce80ddb688a708a/

  1. リセットした Buffer はそのまま使い回し、次の for 内で WriteString を使用して文字列を書き込む

既に容量が確保してある text(bytes.Buffer) に WriteRuneWriteString を使用して文字列を書き込んでいます。

var pos int
for _, word := range words {
    numRunes := utf8.RuneCountInString(word)
    if pos > 0 && pos+numRunes+1 > commentWidth {
        g.P("// ", text.String())
        text.Reset()
        pos = 0
    }
    if pos > 0 {
        text.WriteRune(' ')
        pos++
    }
    text.WriteString(word)
    pos += numRunes
}
if text.Len() > 0 {
    g.P("// ", text.String())
}
io パッケージの WriteXXX について

io パッケージには WriteXXX というメソッドがあります。

func (b *Buffer) WriteString(s string) (n int, err error) {
    b.lastRead = opInvalid
    m, ok := b.tryGrowByReslice(len(s))
    if !ok {
        m = b.grow(len(s))
    }
    return copy(b.buf[m:], s), nil
}

その内部ではtryGrowByResliceが呼び出されており、ここで長さの再計算が行われています。
具体的にはnが元々確保してある cap に収まりきるなら、lを開始位置としてl+nまでのスライスを返します。(ゼロ値で初期化されているため、l+nまでのスライスはnの長さを持つスライスとなります。)

func (b *Buffer) tryGrowByReslice(n int) (int, bool) {
    if l := len(b.buf); n <= cap(b.buf)-l {
        b.buf = b.buf[:l+n]
        return l, true
    }
    return 0, false
}

このように領域の割り当てに関しては io パッケージのレイヤーでも工夫されていることがわかります。

以上より、Proto からの Go コード生成の仕組みは単純ですが、内部的にはメモリ効率を考慮した実装がなされていることが確認できました。

おわりに

ここまで読んでいただきありがとうございました。
AI Shift の開発チームでは、AI チームと連携して AI/LLM を活用したプロダクト開発を通し、日々ユーザのみなさまにより素晴らしい価値・体験を届けるべく開発に取り組んでいます。

AI Shift ではエンジニアの採用に力を入れています!この分野に少しでも興味を持っていただけましたら、カジュアル面談でお話しませんか?(オンライン・19 時以降の面談も可能です!)
【面談フォームはこちら】

明日の Advent Calendar 5 日目の記事は、開発チームの栗崎によるフロントエンド関連の記事の予定です。こちらもよろしくお願いいたします。

PICK UP

TAG