ソースを参照

Signatory Loader (#777)

* Allow Signatory to be run with custom incoming stream

* Allow multiple signatories to be loaded on a server

* Fix merge conflict in server.rs

* Export SignatoryLoader

* Use unit error

* Use Arc instead of reference
David Caseria 3 週間 前
コミット
fe118d180f

+ 1 - 1
crates/cdk-signatory/src/bin/cli/mod.rs

@@ -167,7 +167,7 @@ pub async fn cli_main() -> Result<()> {
 
     let socket_addr = SocketAddr::from_str(&format!("{}:{}", args.listen_addr, args.listen_port))?;
 
-    start_grpc_server(signatory, socket_addr, certs).await?;
+    start_grpc_server(Arc::new(signatory), socket_addr, certs).await?;
 
     Ok(())
 }

+ 2 - 1
crates/cdk-signatory/src/lib.rs

@@ -13,7 +13,8 @@ mod proto;
 
 #[cfg(feature = "grpc")]
 pub use proto::{
-    client::SignatoryRpcClient, server::start_grpc_server, server::start_grpc_server_with_incoming,
+    client::SignatoryRpcClient,
+    server::{start_grpc_server, start_grpc_server_with_incoming, SignatoryLoader},
 };
 
 mod common;

+ 74 - 25
crates/cdk-signatory/src/proto/server.rs

@@ -1,9 +1,11 @@
 //! This module contains the generated gRPC server code for the Signatory service.
 use std::net::SocketAddr;
 use std::path::Path;
+use std::sync::Arc;
 
 use tokio::io::{AsyncRead, AsyncWrite};
 use tokio_stream::Stream;
+use tonic::metadata::MetadataMap;
 use tonic::transport::server::Connected;
 use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig};
 use tonic::{Request, Response, Status};
@@ -12,25 +14,49 @@ use crate::proto::{self, signatory_server};
 use crate::signatory::Signatory;
 
 /// The server implementation for the Signatory service.
-pub struct CdkSignatoryServer<T>
+pub struct CdkSignatoryServer<S, T>
 where
-    T: Signatory + Send + Sync + 'static,
+    S: Signatory + Send + Sync + 'static,
+    T: SignatoryLoader<S> + 'static,
 {
-    inner: T,
+    loader: T,
+    _phantom: std::marker::PhantomData<S>,
+}
+
+impl<S, T> CdkSignatoryServer<S, T>
+where
+    S: Signatory + Send + Sync + 'static,
+    T: SignatoryLoader<S> + 'static,
+{
+    pub fn new(loader: T) -> Self {
+        Self {
+            loader,
+            _phantom: std::marker::PhantomData,
+        }
+    }
+
+    async fn load_signatory(&self, metadata: &MetadataMap) -> Result<Arc<S>, Status> {
+        self.loader
+            .load_signatory(metadata)
+            .await
+            .map_err(|_| Status::internal("Failed to load signatory"))
+    }
 }
 
 #[tonic::async_trait]
-impl<T> signatory_server::Signatory for CdkSignatoryServer<T>
+impl<S, T> signatory_server::Signatory for CdkSignatoryServer<S, T>
 where
-    T: Signatory + Send + Sync + 'static,
+    S: Signatory + Send + Sync + 'static,
+    T: SignatoryLoader<S> + 'static,
 {
     #[tracing::instrument(skip_all)]
     async fn blind_sign(
         &self,
         request: Request<proto::BlindedMessages>,
     ) -> Result<Response<proto::BlindSignResponse>, Status> {
-        let result = match self
-            .inner
+        let metadata = request.metadata();
+        let signatory = self.load_signatory(metadata).await?;
+        let result = match signatory
             .blind_sign(
                 request
                     .into_inner()
@@ -64,8 +90,9 @@ where
         &self,
         request: Request<proto::Proofs>,
     ) -> Result<Response<proto::BooleanResponse>, Status> {
-        let result = match self
-            .inner
+        let metadata = request.metadata();
+        let signatory = self.load_signatory(metadata).await?;
+        let result = match signatory
             .verify_proofs(
                 request
                     .into_inner()
@@ -96,9 +123,11 @@ where
 
     async fn keysets(
         &self,
-        _request: Request<proto::EmptyRequest>,
+        request: Request<proto::EmptyRequest>,
     ) -> Result<Response<proto::KeysResponse>, Status> {
-        let result = match self.inner.keysets().await {
+        let metadata = request.metadata();
+        let signatory = self.load_signatory(metadata).await?;
+        let result = match signatory.keysets().await {
             Ok(result) => proto::KeysResponse {
                 keysets: Some(result.into()),
                 ..Default::default()
@@ -116,8 +145,9 @@ where
         &self,
         request: Request<proto::RotationRequest>,
     ) -> Result<Response<proto::KeyRotationResponse>, Status> {
-        let mint_keyset_info = match self
-            .inner
+        let metadata = request.metadata();
+        let signatory = self.load_signatory(metadata).await?;
+        let mint_keyset_info = match signatory
             .rotate_keyset(request.into_inner().try_into()?)
             .await
         {
@@ -135,6 +165,23 @@ where
     }
 }
 
+/// Trait for loading a signatory instance from gRPC metadata
+#[async_trait::async_trait]
+pub trait SignatoryLoader<S>: Send + Sync {
+    /// Loads the signatory instance based on the provided metadata.
+    async fn load_signatory(&self, metadata: &MetadataMap) -> Result<Arc<S>, ()>;
+}
+
+#[async_trait::async_trait]
+impl<T> SignatoryLoader<T> for Arc<T>
+where
+    T: Signatory + Send + Sync + 'static,
+{
+    async fn load_signatory(&self, _metadata: &MetadataMap) -> Result<Arc<T>, ()> {
+        Ok(self.clone())
+    }
+}
+
 /// Error type for the gRPC server
 #[derive(thiserror::Error, Debug)]
 pub enum Error {
@@ -147,13 +194,14 @@ pub enum Error {
 }
 
 /// Runs the signatory server
-pub async fn start_grpc_server<T, I: AsRef<Path>>(
-    signatory: T,
+pub async fn start_grpc_server<S, T, I: AsRef<Path>>(
+    signatory_loader: T,
     addr: SocketAddr,
     tls_dir: Option<I>,
 ) -> Result<(), Error>
 where
-    T: Signatory + Send + Sync + 'static,
+    S: Signatory + Send + Sync + 'static,
+    T: SignatoryLoader<S> + 'static,
 {
     tracing::info!("Starting RPC server {}", addr);
 
@@ -219,29 +267,30 @@ where
     };
 
     server
-        .add_service(signatory_server::SignatoryServer::new(CdkSignatoryServer {
-            inner: signatory,
-        }))
+        .add_service(signatory_server::SignatoryServer::new(
+            CdkSignatoryServer::new(signatory_loader),
+        ))
         .serve(addr)
         .await?;
     Ok(())
 }
 
 /// Starts the gRPC signatory server with an incoming stream of connections.
-pub async fn start_grpc_server_with_incoming<T, I, IO, IE>(
-    signatory: T,
+pub async fn start_grpc_server_with_incoming<S, T, I, IO, IE>(
+    signatory_loader: T,
     incoming: I,
 ) -> Result<(), Error>
 where
-    T: Signatory + Send + Sync + 'static,
+    S: Signatory + Send + Sync + 'static,
+    T: SignatoryLoader<S> + 'static,
     I: Stream<Item = Result<IO, IE>>,
     IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
     IE: Into<Box<dyn std::error::Error + Send + Sync>>,
 {
     Server::builder()
-        .add_service(signatory_server::SignatoryServer::new(CdkSignatoryServer {
-            inner: signatory,
-        }))
+        .add_service(signatory_server::SignatoryServer::new(
+            CdkSignatoryServer::new(signatory_loader),
+        ))
         .serve_with_incoming(incoming)
         .await?;
     Ok(())