Просмотр исходного кода

feat: add grpc version header to ensure client and server match (#1617)

This change adds a gprc injector that adds a version header to each grpc
request. This is then checked by the server to make sure the client and
server are compatible versions. When the grpc proto file is changed we **MUST**
update the version used in the header.
tsk 4 недель назад
Родитель
Сommit
03a5a914b8

+ 1 - 0
Cargo.lock

@@ -1322,6 +1322,7 @@ dependencies = [
  "serde_with",
  "serde_with",
  "thiserror 2.0.18",
  "thiserror 2.0.18",
  "tokio",
  "tokio",
+ "tonic 0.14.2",
  "tracing",
  "tracing",
  "url",
  "url",
  "utoipa",
  "utoipa",

+ 2 - 0
crates/cdk-common/Cargo.toml

@@ -20,6 +20,7 @@ mint = ["cashu/mint", "dep:uuid"]
 nostr = ["wallet", "cashu/nostr"]
 nostr = ["wallet", "cashu/nostr"]
 prometheus = ["cdk-prometheus/default"]
 prometheus = ["cdk-prometheus/default"]
 http = ["dep:cdk-http-client"]
 http = ["dep:cdk-http-client"]
+grpc = ["dep:tonic"]
 
 
 [dependencies]
 [dependencies]
 cdk-http-client = { workspace = true, optional = true }
 cdk-http-client = { workspace = true, optional = true }
@@ -44,6 +45,7 @@ serde_with.workspace = true
 web-time.workspace = true
 web-time.workspace = true
 parking_lot = "0.12.5"
 parking_lot = "0.12.5"
 paste = "1.0.15"
 paste = "1.0.15"
+tonic = { workspace = true, optional = true }
 
 
 [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
 [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
 tokio = { version = "1", default-features = false, features = ["rt", "rt-multi-thread", "macros", "test-util", "sync"] }
 tokio = { version = "1", default-features = false, features = ["rt", "rt-multi-thread", "macros", "test-util", "sync"] }

+ 45 - 0
crates/cdk-common/src/grpc.rs

@@ -0,0 +1,45 @@
+//! gRPC version checking utilities
+
+use tonic::{Request, Status};
+
+/// Header name for protocol version
+pub const VERSION_HEADER: &str = "x-cdk-protocol-version";
+
+/// Creates a client-side interceptor that injects a specific protocol version into outgoing requests
+///
+/// # Panics
+/// Panics if the version string is not a valid gRPC metadata ASCII value
+pub fn create_version_inject_interceptor(
+    version: &'static str,
+) -> impl Fn(Request<()>) -> Result<Request<()>, Status> + Clone {
+    move |mut request: Request<()>| {
+        request.metadata_mut().insert(
+            VERSION_HEADER,
+            version.parse().expect("Invalid protocol version"),
+        );
+        Ok(request)
+    }
+}
+
+/// Creates a server-side interceptor that validates a specific protocol version on incoming requests
+pub fn create_version_check_interceptor(
+    expected_version: &'static str,
+) -> impl Fn(Request<()>) -> Result<Request<()>, Status> + Clone {
+    move |request: Request<()>| match request.metadata().get(VERSION_HEADER) {
+        Some(version) => {
+            let version = version
+                .to_str()
+                .map_err(|_| Status::invalid_argument("Invalid protocol version header"))?;
+            if version != expected_version {
+                return Err(Status::failed_precondition(format!(
+                    "Protocol version mismatch: server={}, client={}",
+                    expected_version, version
+                )));
+            }
+            Ok(request)
+        }
+        None => Err(Status::failed_precondition(
+            "Missing x-cdk-protocol-version header",
+        )),
+    }
+}

+ 12 - 0
crates/cdk-common/src/lib.rs

@@ -8,6 +8,18 @@
 
 
 pub mod task;
 pub mod task;
 
 
+/// Protocol version for gRPC Mint RPC communication
+pub const MINT_RPC_PROTOCOL_VERSION: &str = "1.0.0";
+
+/// Protocol version for gRPC Signatory communication
+pub const SIGNATORY_PROTOCOL_VERSION: &str = "1.0.0";
+
+/// Protocol version for gRPC Payment Processor communication
+pub const PAYMENT_PROCESSOR_PROTOCOL_VERSION: &str = "1.0.0";
+
+#[cfg(feature = "grpc")]
+pub mod grpc;
+
 pub mod common;
 pub mod common;
 pub mod database;
 pub mod database;
 pub mod error;
 pub mod error;

+ 1 - 1
crates/cdk-mint-rpc/Cargo.toml

@@ -23,7 +23,7 @@ anyhow.workspace = true
 cdk = { workspace = true, features = [
 cdk = { workspace = true, features = [
     "mint",
     "mint",
 ] }
 ] }
-cdk-common.workspace = true
+cdk-common = { workspace = true, features = ["grpc"] }
 clap.workspace = true
 clap.workspace = true
 tonic = { workspace = true, features = ["transport", "tls-ring", "codegen", "router"] }
 tonic = { workspace = true, features = ["transport", "tls-ring", "codegen", "router"] }
 tracing.workspace = true
 tracing.workspace = true

+ 15 - 1
crates/cdk-mint-rpc/src/bin/mint_rpc_cli.rs

@@ -3,12 +3,23 @@
 use std::path::PathBuf;
 use std::path::PathBuf;
 
 
 use anyhow::{anyhow, Result};
 use anyhow::{anyhow, Result};
+use cdk_common::grpc::VERSION_HEADER;
 use cdk_mint_rpc::cdk_mint_client::CdkMintClient;
 use cdk_mint_rpc::cdk_mint_client::CdkMintClient;
 use cdk_mint_rpc::mint_rpc_cli::subcommands;
 use cdk_mint_rpc::mint_rpc_cli::subcommands;
 use cdk_mint_rpc::GetInfoRequest;
 use cdk_mint_rpc::GetInfoRequest;
 use clap::{Parser, Subcommand};
 use clap::{Parser, Subcommand};
+use tonic::metadata::MetadataValue;
 use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
 use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
 use tonic::Request;
 use tonic::Request;
+
+/// Helper function to add version header to a request
+fn with_version_header<T>(mut request: Request<T>) -> Request<T> {
+    request.metadata_mut().insert(
+        VERSION_HEADER,
+        MetadataValue::from_static(cdk_common::MINT_RPC_PROTOCOL_VERSION),
+    );
+    request
+}
 use tracing_subscriber::EnvFilter;
 use tracing_subscriber::EnvFilter;
 
 
 /// Common CLI arguments for CDK binaries
 /// Common CLI arguments for CDK binaries
@@ -150,11 +161,14 @@ async fn main() -> Result<()> {
             .await?
             .await?
     };
     };
 
 
+    // Create client
     let mut client = CdkMintClient::new(channel);
     let mut client = CdkMintClient::new(channel);
 
 
     match cli.command {
     match cli.command {
         Commands::GetInfo => {
         Commands::GetInfo => {
-            let response = client.get_info(Request::new(GetInfoRequest {})).await?;
+            let response = client
+                .get_info(with_version_header(Request::new(GetInfoRequest {})))
+                .await?;
             let info = response.into_inner();
             let info = response.into_inner();
             println!(
             println!(
                 "name:             {}",
                 "name:             {}",

+ 3 - 0
crates/cdk-mint-rpc/src/lib.rs

@@ -5,3 +5,6 @@ pub mod proto;
 pub mod mint_rpc_cli;
 pub mod mint_rpc_cli;
 
 
 pub use proto::*;
 pub use proto::*;
+
+/// Type alias for the CdkMintClient that works with any tower service
+pub type CdkMintClient<S> = cdk_mint_client::CdkMintClient<S>;

+ 15 - 0
crates/cdk-mint-rpc/src/mint_rpc_cli/subcommands/mod.rs

@@ -1,3 +1,18 @@
+//! Subcommands for the mint RPC CLI
+
+use cdk_common::grpc::VERSION_HEADER;
+use tonic::metadata::MetadataValue;
+use tonic::Request;
+
+/// Helper function to add version header to a request
+pub fn with_version_header<T>(mut request: Request<T>) -> Request<T> {
+    request.metadata_mut().insert(
+        VERSION_HEADER,
+        MetadataValue::from_static(cdk_common::MINT_RPC_PROTOCOL_VERSION),
+    );
+    request
+}
+
 /// Module for rotating to the next keyset
 /// Module for rotating to the next keyset
 mod rotate_next_keyset;
 mod rotate_next_keyset;
 /// Module for updating mint contact information
 /// Module for updating mint contact information

+ 10 - 8
crates/cdk-mint-rpc/src/mint_rpc_cli/subcommands/update_nut04.rs

@@ -52,14 +52,16 @@ pub async fn update_nut04(
         .map(|description| MintMethodOptions { description });
         .map(|description| MintMethodOptions { description });
 
 
     let _response = client
     let _response = client
-        .update_nut04(Request::new(UpdateNut04Request {
-            method: sub_command_args.method.clone(),
-            unit: sub_command_args.unit.clone(),
-            disabled: sub_command_args.disabled,
-            min_amount: sub_command_args.min_amount,
-            max_amount: sub_command_args.max_amount,
-            options,
-        }))
+        .update_nut04(crate::mint_rpc_cli::subcommands::with_version_header(
+            Request::new(UpdateNut04Request {
+                method: sub_command_args.method.clone(),
+                unit: sub_command_args.unit.clone(),
+                disabled: sub_command_args.disabled,
+                min_amount: sub_command_args.min_amount,
+                max_amount: sub_command_args.max_amount,
+                options,
+            }),
+        ))
         .await?;
         .await?;
 
 
     Ok(())
     Ok(())

+ 1 - 1
crates/cdk-mint-rpc/src/proto/cdk-mint-rpc.proto

@@ -1,6 +1,6 @@
 syntax = "proto3";
 syntax = "proto3";
 
 
-package cdk_mint_rpc;
+package cdk_mint_management_v1;
 
 
 service CdkMint {
 service CdkMint {
     rpc GetInfo(GetInfoRequest) returns (GetInfoResponse) {}
     rpc GetInfo(GetInfoRequest) returns (GetInfoResponse) {}

+ 3 - 1
crates/cdk-mint-rpc/src/proto/mod.rs

@@ -1,7 +1,9 @@
 //! CDK mint proto types
 //! CDK mint proto types
 
 
-tonic::include_proto!("cdk_mint_rpc");
+tonic::include_proto!("cdk_mint_management_v1");
 
 
 mod server;
 mod server;
 
 
+/// Protocol version for gRPC Mint RPC communication
+pub use cdk_common::MINT_RPC_PROTOCOL_VERSION as PROTOCOL_VERSION;
 pub use server::MintRPCServer;
 pub use server::MintRPCServer;

+ 15 - 6
crates/cdk-mint-rpc/src/proto/server.rs

@@ -9,6 +9,7 @@ use cdk::nuts::nut05::MeltMethodSettings;
 use cdk::nuts::{CurrencyUnit, MintQuoteState, PaymentMethod};
 use cdk::nuts::{CurrencyUnit, MintQuoteState, PaymentMethod};
 use cdk::types::QuoteTTL;
 use cdk::types::QuoteTTL;
 use cdk::Amount;
 use cdk::Amount;
+use cdk_common::grpc::create_version_check_interceptor;
 use cdk_common::payment::WaitPaymentResponse;
 use cdk_common::payment::WaitPaymentResponse;
 use thiserror::Error;
 use thiserror::Error;
 use tokio::sync::Notify;
 use tokio::sync::Notify;
@@ -135,13 +136,19 @@ impl MintRPCServer {
                     .identity(server_identity)
                     .identity(server_identity)
                     .client_ca_root(client_ca_cert);
                     .client_ca_root(client_ca_cert);
 
 
-                Server::builder()
-                    .tls_config(tls_config)?
-                    .add_service(CdkMintServer::new(self.clone()))
+                Server::builder().tls_config(tls_config)?.add_service(
+                    CdkMintServer::with_interceptor(
+                        self.clone(),
+                        create_version_check_interceptor(cdk_common::MINT_RPC_PROTOCOL_VERSION),
+                    ),
+                )
             }
             }
             None => {
             None => {
                 tracing::warn!("No valid TLS configuration found, starting insecure server");
                 tracing::warn!("No valid TLS configuration found, starting insecure server");
-                Server::builder().add_service(CdkMintServer::new(self.clone()))
+                Server::builder().add_service(CdkMintServer::with_interceptor(
+                    self.clone(),
+                    create_version_check_interceptor(cdk_common::MINT_RPC_PROTOCOL_VERSION),
+                ))
             }
             }
         };
         };
 
 
@@ -223,7 +230,7 @@ impl CdkMint for MintRPCServer {
             })
             })
             .collect();
             .collect();
 
 
-        Ok(Response::new(GetInfoResponse {
+        let response = Response::new(GetInfoResponse {
             name: info.name,
             name: info.name,
             description: info.description,
             description: info.description,
             long_description: info.description_long,
             long_description: info.description_long,
@@ -234,7 +241,9 @@ impl CdkMint for MintRPCServer {
             urls: info.urls.unwrap_or_default(),
             urls: info.urls.unwrap_or_default(),
             total_issued: total_issued.into(),
             total_issued: total_issued.into(),
             total_redeemed: total_redeemed.into(),
             total_redeemed: total_redeemed.into(),
-        }))
+        });
+
+        Ok(response)
     }
     }
 
 
     /// Updates the mint's message of the day
     /// Updates the mint's message of the day

+ 1 - 1
crates/cdk-payment-processor/Cargo.toml

@@ -26,7 +26,7 @@ anyhow.workspace = true
 async-trait.workspace = true
 async-trait.workspace = true
 bitcoin.workspace = true
 bitcoin.workspace = true
 cashu.workspace = true
 cashu.workspace = true
-cdk-common = { workspace = true, features = ["mint"] }
+cdk-common = { workspace = true, features = ["mint", "grpc"] }
 cdk-cln = { workspace = true, optional = true }
 cdk-cln = { workspace = true, optional = true }
 cdk-lnd = { workspace = true, optional = true }
 cdk-lnd = { workspace = true, optional = true }
 cdk-fake-wallet = { workspace = true, optional = true }
 cdk-fake-wallet = { workspace = true, optional = true }

+ 29 - 14
crates/cdk-payment-processor/src/proto/client.rs

@@ -4,6 +4,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use std::sync::Arc;
 
 
 use anyhow::anyhow;
 use anyhow::anyhow;
+use cdk_common::grpc::VERSION_HEADER;
 use cdk_common::payment::{
 use cdk_common::payment::{
     CreateIncomingPaymentResponse, IncomingPaymentOptions as CdkIncomingPaymentOptions,
     CreateIncomingPaymentResponse, IncomingPaymentOptions as CdkIncomingPaymentOptions,
     MakePaymentResponse as CdkMakePaymentResponse, MintPayment,
     MakePaymentResponse as CdkMakePaymentResponse, MintPayment,
@@ -11,6 +12,7 @@ use cdk_common::payment::{
 };
 };
 use futures::{Stream, StreamExt};
 use futures::{Stream, StreamExt};
 use tokio_util::sync::CancellationToken;
 use tokio_util::sync::CancellationToken;
+use tonic::metadata::MetadataValue;
 use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
 use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
 use tonic::{async_trait, Request};
 use tonic::{async_trait, Request};
 use tracing::instrument;
 use tracing::instrument;
@@ -21,6 +23,15 @@ use crate::proto::{
     IncomingPaymentOptions, MakePaymentRequest, OutgoingPaymentRequestType, PaymentQuoteRequest,
     IncomingPaymentOptions, MakePaymentRequest, OutgoingPaymentRequestType, PaymentQuoteRequest,
 };
 };
 
 
+/// Helper function to add version header to a request
+fn with_version_header<T>(mut request: Request<T>) -> Request<T> {
+    request.metadata_mut().insert(
+        VERSION_HEADER,
+        MetadataValue::from_static(cdk_common::PAYMENT_PROCESSOR_PROTOCOL_VERSION),
+    );
+    request
+}
+
 /// Payment Processor
 /// Payment Processor
 #[derive(Clone)]
 #[derive(Clone)]
 pub struct PaymentProcessorClient {
 pub struct PaymentProcessorClient {
@@ -96,7 +107,7 @@ impl MintPayment for PaymentProcessorClient {
     async fn get_settings(&self) -> Result<cdk_common::payment::SettingsResponse, Self::Err> {
     async fn get_settings(&self) -> Result<cdk_common::payment::SettingsResponse, Self::Err> {
         let mut inner = self.inner.clone();
         let mut inner = self.inner.clone();
         let response = inner
         let response = inner
-            .get_settings(Request::new(EmptyRequest {}))
+            .get_settings(with_version_header(Request::new(EmptyRequest {})))
             .await
             .await
             .map_err(|err| {
             .map_err(|err| {
                 tracing::error!("Could not get settings: {}", err);
                 tracing::error!("Could not get settings: {}", err);
@@ -163,10 +174,10 @@ impl MintPayment for PaymentProcessorClient {
         };
         };
 
 
         let response = inner
         let response = inner
-            .create_payment(Request::new(CreatePaymentRequest {
+            .create_payment(with_version_header(Request::new(CreatePaymentRequest {
                 unit: unit.to_string(),
                 unit: unit.to_string(),
                 options: Some(proto_options),
                 options: Some(proto_options),
-            }))
+            })))
             .await
             .await
             .map_err(|err| {
             .map_err(|err| {
                 tracing::error!("Could not create payment request: {}", err);
                 tracing::error!("Could not create payment request: {}", err);
@@ -217,13 +228,13 @@ impl MintPayment for PaymentProcessorClient {
         };
         };
 
 
         let response = inner
         let response = inner
-            .get_payment_quote(Request::new(PaymentQuoteRequest {
+            .get_payment_quote(with_version_header(Request::new(PaymentQuoteRequest {
                 request: proto_request,
                 request: proto_request,
                 unit: unit.to_string(),
                 unit: unit.to_string(),
                 options: proto_options.map(Into::into),
                 options: proto_options.map(Into::into),
                 request_type: request_type.into(),
                 request_type: request_type.into(),
                 extra_json,
                 extra_json,
-            }))
+            })))
             .await
             .await
             .map_err(|err| {
             .map_err(|err| {
                 tracing::error!("Could not get payment quote: {}", err);
                 tracing::error!("Could not get payment quote: {}", err);
@@ -282,11 +293,11 @@ impl MintPayment for PaymentProcessorClient {
         };
         };
 
 
         let response = inner
         let response = inner
-            .make_payment(Request::new(MakePaymentRequest {
+            .make_payment(with_version_header(Request::new(MakePaymentRequest {
                 payment_options: Some(payment_options),
                 payment_options: Some(payment_options),
                 partial_amount: None,
                 partial_amount: None,
                 max_fee_amount: None,
                 max_fee_amount: None,
-            }))
+            })))
             .await
             .await
             .map_err(|err| {
             .map_err(|err| {
                 tracing::error!("Could not pay payment request: {}", err);
                 tracing::error!("Could not pay payment request: {}", err);
@@ -316,7 +327,7 @@ impl MintPayment for PaymentProcessorClient {
         tracing::debug!("Client waiting for payment");
         tracing::debug!("Client waiting for payment");
         let mut inner = self.inner.clone();
         let mut inner = self.inner.clone();
         let stream = inner
         let stream = inner
-            .wait_incoming_payment(EmptyRequest {})
+            .wait_incoming_payment(with_version_header(Request::new(EmptyRequest {})))
             .await
             .await
             .map_err(|err| {
             .map_err(|err| {
                 tracing::error!("Could not check incoming payment stream: {}", err);
                 tracing::error!("Could not check incoming payment stream: {}", err);
@@ -372,9 +383,11 @@ impl MintPayment for PaymentProcessorClient {
     ) -> Result<Vec<WaitPaymentResponse>, Self::Err> {
     ) -> Result<Vec<WaitPaymentResponse>, Self::Err> {
         let mut inner = self.inner.clone();
         let mut inner = self.inner.clone();
         let response = inner
         let response = inner
-            .check_incoming_payment(Request::new(CheckIncomingPaymentRequest {
-                request_identifier: Some(payment_identifier.clone().into()),
-            }))
+            .check_incoming_payment(with_version_header(Request::new(
+                CheckIncomingPaymentRequest {
+                    request_identifier: Some(payment_identifier.clone().into()),
+                },
+            )))
             .await
             .await
             .map_err(|err| {
             .map_err(|err| {
                 tracing::error!("Could not check incoming payment: {}", err);
                 tracing::error!("Could not check incoming payment: {}", err);
@@ -395,9 +408,11 @@ impl MintPayment for PaymentProcessorClient {
     ) -> Result<CdkMakePaymentResponse, Self::Err> {
     ) -> Result<CdkMakePaymentResponse, Self::Err> {
         let mut inner = self.inner.clone();
         let mut inner = self.inner.clone();
         let response = inner
         let response = inner
-            .check_outgoing_payment(Request::new(CheckOutgoingPaymentRequest {
-                request_identifier: Some(payment_identifier.clone().into()),
-            }))
+            .check_outgoing_payment(with_version_header(Request::new(
+                CheckOutgoingPaymentRequest {
+                    request_identifier: Some(payment_identifier.clone().into()),
+                },
+            )))
             .await
             .await
             .map_err(|err| {
             .map_err(|err| {
                 tracing::error!("Could not check outgoing payment: {}", err);
                 tracing::error!("Could not check outgoing payment: {}", err);

+ 15 - 4
crates/cdk-payment-processor/src/proto/server.rs

@@ -5,6 +5,7 @@ use std::str::FromStr;
 use std::sync::Arc;
 use std::sync::Arc;
 use std::time::Duration;
 use std::time::Duration;
 
 
+use cdk_common::grpc::create_version_check_interceptor;
 use cdk_common::payment::{IncomingPaymentOptions, MintPayment};
 use cdk_common::payment::{IncomingPaymentOptions, MintPayment};
 use cdk_common::CurrencyUnit;
 use cdk_common::CurrencyUnit;
 use futures::{Stream, StreamExt};
 use futures::{Stream, StreamExt};
@@ -103,13 +104,23 @@ impl PaymentProcessorServer {
                     .identity(server_identity)
                     .identity(server_identity)
                     .client_ca_root(client_ca_cert);
                     .client_ca_root(client_ca_cert);
 
 
-                Server::builder()
-                    .tls_config(tls_config)?
-                    .add_service(CdkPaymentProcessorServer::new(self.clone()))
+                Server::builder().tls_config(tls_config)?.add_service(
+                    CdkPaymentProcessorServer::with_interceptor(
+                        self.clone(),
+                        create_version_check_interceptor(
+                            cdk_common::PAYMENT_PROCESSOR_PROTOCOL_VERSION,
+                        ),
+                    ),
+                )
             }
             }
             None => {
             None => {
                 tracing::warn!("No valid TLS configuration found, starting insecure server");
                 tracing::warn!("No valid TLS configuration found, starting insecure server");
-                Server::builder().add_service(CdkPaymentProcessorServer::new(self.clone()))
+                Server::builder().add_service(CdkPaymentProcessorServer::with_interceptor(
+                    self.clone(),
+                    create_version_check_interceptor(
+                        cdk_common::PAYMENT_PROCESSOR_PROTOCOL_VERSION,
+                    ),
+                ))
             }
             }
         };
         };
 
 

+ 1 - 0
crates/cdk-signatory/Cargo.toml

@@ -20,6 +20,7 @@ async-trait.workspace = true
 bitcoin.workspace = true
 bitcoin.workspace = true
 cdk-common = { workspace = true, default-features = false, features = [
 cdk-common = { workspace = true, default-features = false, features = [
     "mint",
     "mint",
+    "grpc",
 ] }
 ] }
 tonic = { workspace = true, optional = true, features = ["transport", "tls-ring", "codegen", "router"] }
 tonic = { workspace = true, optional = true, features = ["transport", "tls-ring", "codegen", "router"] }
 tonic-prost = { workspace = true, optional = true }
 tonic-prost = { workspace = true, optional = true }

+ 17 - 4
crates/cdk-signatory/src/proto/client.rs

@@ -1,7 +1,9 @@
 use std::path::Path;
 use std::path::Path;
 
 
 use cdk_common::error::Error;
 use cdk_common::error::Error;
+use cdk_common::grpc::VERSION_HEADER;
 use cdk_common::{BlindSignature, BlindedMessage, Proof};
 use cdk_common::{BlindSignature, BlindedMessage, Proof};
+use tonic::metadata::MetadataValue;
 use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
 use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
 
 
 use crate::proto::signatory_client::SignatoryClient;
 use crate::proto::signatory_client::SignatoryClient;
@@ -34,6 +36,15 @@ pub enum ClientError {
     InvalidUrl,
     InvalidUrl,
 }
 }
 
 
+/// Helper function to add version header to a request
+fn with_version_header<T>(mut request: tonic::Request<T>) -> tonic::Request<T> {
+    request.metadata_mut().insert(
+        VERSION_HEADER,
+        MetadataValue::from_static(cdk_common::SIGNATORY_PROTOCOL_VERSION),
+    );
+    request
+}
+
 impl SignatoryRpcClient {
 impl SignatoryRpcClient {
     /// Create a new RemoteSigner from a tonic transport channel.
     /// Create a new RemoteSigner from a tonic transport channel.
     pub async fn new<A: AsRef<Path>>(url: String, tls_dir: Option<A>) -> Result<Self, ClientError> {
     pub async fn new<A: AsRef<Path>>(url: String, tls_dir: Option<A>) -> Result<Self, ClientError> {
@@ -112,7 +123,7 @@ impl Signatory for SignatoryRpcClient {
 
 
         self.client
         self.client
             .clone()
             .clone()
-            .blind_sign(req)
+            .blind_sign(with_version_header(tonic::Request::new(req)))
             .await
             .await
             .map(|response| {
             .map(|response| {
                 handle_error!(response, sigs)
                 handle_error!(response, sigs)
@@ -129,7 +140,7 @@ impl Signatory for SignatoryRpcClient {
         let req: super::Proofs = proofs.into();
         let req: super::Proofs = proofs.into();
         self.client
         self.client
             .clone()
             .clone()
-            .verify_proofs(req)
+            .verify_proofs(with_version_header(tonic::Request::new(req)))
             .await
             .await
             .map(|response| {
             .map(|response| {
                 if handle_error!(response, success, scalar) {
                 if handle_error!(response, success, scalar) {
@@ -145,7 +156,9 @@ impl Signatory for SignatoryRpcClient {
     async fn keysets(&self) -> Result<SignatoryKeysets, Error> {
     async fn keysets(&self) -> Result<SignatoryKeysets, Error> {
         self.client
         self.client
             .clone()
             .clone()
-            .keysets(super::EmptyRequest {})
+            .keysets(with_version_header(tonic::Request::new(
+                super::EmptyRequest {},
+            )))
             .await
             .await
             .map(|response| handle_error!(response, keysets).try_into())
             .map(|response| handle_error!(response, keysets).try_into())
             .map_err(|e| Error::Custom(e.to_string()))?
             .map_err(|e| Error::Custom(e.to_string()))?
@@ -156,7 +169,7 @@ impl Signatory for SignatoryRpcClient {
         let req: super::RotationRequest = args.into();
         let req: super::RotationRequest = args.into();
         self.client
         self.client
             .clone()
             .clone()
-            .rotate_keyset(req)
+            .rotate_keyset(with_version_header(tonic::Request::new(req)))
             .await
             .await
             .map(|response| handle_error!(response, keyset).try_into())
             .map(|response| handle_error!(response, keyset).try_into())
             .map_err(|e| Error::Custom(e.to_string()))?
             .map_err(|e| Error::Custom(e.to_string()))?

+ 5 - 2
crates/cdk-signatory/src/proto/server.rs

@@ -3,6 +3,7 @@ use std::net::SocketAddr;
 use std::path::Path;
 use std::path::Path;
 use std::sync::Arc;
 use std::sync::Arc;
 
 
+use cdk_common::grpc::create_version_check_interceptor;
 use tokio::io::{AsyncRead, AsyncWrite};
 use tokio::io::{AsyncRead, AsyncWrite};
 use tokio_stream::Stream;
 use tokio_stream::Stream;
 use tonic::metadata::MetadataMap;
 use tonic::metadata::MetadataMap;
@@ -267,8 +268,9 @@ where
     };
     };
 
 
     server
     server
-        .add_service(signatory_server::SignatoryServer::new(
+        .add_service(signatory_server::SignatoryServer::with_interceptor(
             CdkSignatoryServer::new(signatory_loader),
             CdkSignatoryServer::new(signatory_loader),
+            create_version_check_interceptor(cdk_common::SIGNATORY_PROTOCOL_VERSION),
         ))
         ))
         .serve(addr)
         .serve(addr)
         .await?;
         .await?;
@@ -288,8 +290,9 @@ where
     IE: Into<Box<dyn std::error::Error + Send + Sync>>,
     IE: Into<Box<dyn std::error::Error + Send + Sync>>,
 {
 {
     Server::builder()
     Server::builder()
-        .add_service(signatory_server::SignatoryServer::new(
+        .add_service(signatory_server::SignatoryServer::with_interceptor(
             CdkSignatoryServer::new(signatory_loader),
             CdkSignatoryServer::new(signatory_loader),
+            create_version_check_interceptor(cdk_common::SIGNATORY_PROTOCOL_VERSION),
         ))
         ))
         .serve_with_incoming(incoming)
         .serve_with_incoming(incoming)
         .await?;
         .await?;