diff --git a/rust/protobuf.go b/rust/protobuf.go index ebb1c3cb6..28a4597f9 100644 --- a/rust/protobuf.go +++ b/rust/protobuf.go @@ -22,9 +22,18 @@ var ( defaultProtobufFlags = []string{""} ) +type PluginType int + +const ( + Protobuf PluginType = iota + Grpc +) + func init() { android.RegisterModuleType("rust_protobuf", RustProtobufFactory) android.RegisterModuleType("rust_protobuf_host", RustProtobufHostFactory) + android.RegisterModuleType("rust_grpcio", RustGrpcioFactory) + android.RegisterModuleType("rust_grpcio_host", RustGrpcioHostFactory) } var _ SourceProvider = (*protobufDecorator)(nil) @@ -41,15 +50,18 @@ type protobufDecorator struct { *BaseSourceProvider Properties ProtobufProperties + plugin PluginType } func (proto *protobufDecorator) GenerateSource(ctx ModuleContext, deps PathDeps) android.Path { var protoFlags android.ProtoFlags - pluginPath := ctx.Config().HostToolPath(ctx, "protoc-gen-rust") + var pluginPath android.Path protoFlags.OutTypeFlag = "--rust_out" + outDir := android.PathForModuleOut(ctx) + + pluginPath, protoFlags = proto.setupPlugin(ctx, protoFlags, outDir) - protoFlags.Flags = append(protoFlags.Flags, " --plugin="+pluginPath.String()) protoFlags.Flags = append(protoFlags.Flags, defaultProtobufFlags...) protoFlags.Flags = append(protoFlags.Flags, proto.Properties.Proto_flags...) @@ -60,7 +72,6 @@ func (proto *protobufDecorator) GenerateSource(ctx ModuleContext, deps PathDeps) ctx.PropertyErrorf("proto", "invalid path to proto file") } - outDir := android.PathForModuleOut(ctx) stem := proto.BaseSourceProvider.getStem(ctx) // rust protobuf-codegen output .rs stemFile := android.PathForModuleOut(ctx, stem+".rs") @@ -79,6 +90,23 @@ func (proto *protobufDecorator) GenerateSource(ctx ModuleContext, deps PathDeps) return modFile } +func (proto *protobufDecorator) setupPlugin(ctx ModuleContext, protoFlags android.ProtoFlags, outDir android.ModuleOutPath) (android.Path, android.ProtoFlags) { + var pluginPath android.Path + + if proto.plugin == Protobuf { + pluginPath = ctx.Config().HostToolPath(ctx, "protoc-gen-rust") + protoFlags.Flags = append(protoFlags.Flags, "--plugin="+pluginPath.String()) + } else if proto.plugin == Grpc { + pluginPath = ctx.Config().HostToolPath(ctx, "grpc_rust_plugin") + protoFlags.Flags = append(protoFlags.Flags, "--grpc_out="+outDir.String()) + protoFlags.Flags = append(protoFlags.Flags, "--plugin=protoc-gen-grpc="+pluginPath.String()) + } else { + ctx.ModuleErrorf("Unknown protobuf plugin type requested") + } + + return pluginPath, protoFlags +} + func (proto *protobufDecorator) SourceProviderProps() []interface{} { return append(proto.BaseSourceProvider.SourceProviderProps(), &proto.Properties) } @@ -104,10 +132,34 @@ func RustProtobufHostFactory() android.Module { return module.Init() } +func RustGrpcioFactory() android.Module { + module, _ := NewRustGrpcio(android.HostAndDeviceSupported) + return module.Init() +} + +// A host-only variant of rust_protobuf. Refer to rust_protobuf for more details. +func RustGrpcioHostFactory() android.Module { + module, _ := NewRustGrpcio(android.HostSupported) + return module.Init() +} + func NewRustProtobuf(hod android.HostOrDeviceSupported) (*Module, *protobufDecorator) { protobuf := &protobufDecorator{ BaseSourceProvider: NewSourceProvider(), Properties: ProtobufProperties{}, + plugin: Protobuf, + } + + module := NewSourceProviderModule(hod, protobuf, false) + + return module, protobuf +} + +func NewRustGrpcio(hod android.HostOrDeviceSupported) (*Module, *protobufDecorator) { + protobuf := &protobufDecorator{ + BaseSourceProvider: NewSourceProvider(), + Properties: ProtobufProperties{}, + plugin: Grpc, } module := NewSourceProviderModule(hod, protobuf, false) diff --git a/rust/protobuf_test.go b/rust/protobuf_test.go index bd11a5ae3..7c3907100 100644 --- a/rust/protobuf_test.go +++ b/rust/protobuf_test.go @@ -15,8 +15,10 @@ package rust import ( - "android/soong/android" + "strings" "testing" + + "android/soong/android" ) func TestRustProtobuf(t *testing.T) { @@ -28,12 +30,41 @@ func TestRustProtobuf(t *testing.T) { source_stem: "buf", } `) - // Check that there's a rule to generate the expected output - _ = ctx.ModuleForTests("librust_proto", "android_arm64_armv8-a_source").Output("buf.rs") - // Check that libprotobuf is added as a dependency. librust_proto := ctx.ModuleForTests("librust_proto", "android_arm64_armv8-a_dylib").Module().(*Module) if !android.InList("libprotobuf", librust_proto.Properties.AndroidMkDylibs) { t.Errorf("libprotobuf dependency missing for rust_protobuf (dependency missing from AndroidMkDylibs)") } + + // Make sure the correct plugin is being used. + librust_proto_out := ctx.ModuleForTests("librust_proto", "android_arm64_armv8-a_source").Output("buf.rs") + cmd := librust_proto_out.RuleParams.Command + if w := "protoc-gen-rust"; !strings.Contains(cmd, w) { + t.Errorf("expected %q in %q", w, cmd) + } + +} + +func TestRustGrpcio(t *testing.T) { + ctx := testRust(t, ` + rust_grpcio { + name: "librust_grpcio", + proto: "buf.proto", + crate_name: "rust_grpcio", + source_stem: "buf", + } + `) + + // Check that libprotobuf is added as a dependency. + librust_grpcio_module := ctx.ModuleForTests("librust_grpcio", "android_arm64_armv8-a_dylib").Module().(*Module) + if !android.InList("libprotobuf", librust_grpcio_module.Properties.AndroidMkDylibs) { + t.Errorf("libprotobuf dependency missing for rust_grpcio (dependency missing from AndroidMkDylibs)") + } + + // Make sure the correct plugin is being used. + librust_grpcio_out := ctx.ModuleForTests("librust_grpcio", "android_arm64_armv8-a_source").Output("buf.rs") + cmd := librust_grpcio_out.RuleParams.Command + if w := "protoc-gen-grpc"; !strings.Contains(cmd, w) { + t.Errorf("expected %q in %q", w, cmd) + } } diff --git a/rust/testing.go b/rust/testing.go index 42b0da171..4a1894c05 100644 --- a/rust/testing.go +++ b/rust/testing.go @@ -132,6 +132,8 @@ func CreateTestContext() *android.TestContext { ctx.RegisterModuleType("rust_ffi_host", RustFFIHostFactory) ctx.RegisterModuleType("rust_ffi_host_shared", RustFFISharedHostFactory) ctx.RegisterModuleType("rust_ffi_host_static", RustFFIStaticHostFactory) + ctx.RegisterModuleType("rust_grpcio", RustGrpcioFactory) + ctx.RegisterModuleType("rust_grpcio_host", RustGrpcioHostFactory) ctx.RegisterModuleType("rust_proc_macro", ProcMacroFactory) ctx.RegisterModuleType("rust_protobuf", RustProtobufFactory) ctx.RegisterModuleType("rust_protobuf_host", RustProtobufHostFactory)