Ver código fonte

Merge pull request #4 from crodas/feature/client-list

Added `client list` support
César D. Rodas 3 anos atrás
pai
commit
92aa212ff9
4 arquivos alterados com 61 adições e 26 exclusões
  1. 10 2
      src/cmd/client.rs
  2. 42 19
      src/connection.rs
  3. 2 2
      src/macros.rs
  4. 7 3
      src/server.rs

+ 10 - 2
src/cmd/client.rs

@@ -1,7 +1,8 @@
 use crate::{connection::Connection, error::Error, option, value::Value};
+use std::sync::Arc;
 use bytes::Bytes;
 
-pub fn client(conn: &mut Connection, args: &[Bytes]) -> Result<Value, Error> {
+pub fn client(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     let sub = unsafe { std::str::from_utf8_unchecked(&args[1]) }.to_string();
 
     let expected = match sub.to_lowercase().as_str() {
@@ -19,7 +20,14 @@ pub fn client(conn: &mut Connection, args: &[Bytes]) -> Result<Value, Error> {
     match sub.to_lowercase().as_str() {
         "id" => Ok((conn.id() as i64).into()),
         "info" => Ok(conn.info().as_str().into()),
-        "getname" => Ok(option!(conn.name().to_owned())),
+        "getname" => Ok(option!(conn.name())),
+        "list" => {
+            let mut v: Vec<Value> = vec![];
+            conn.all_connections().iter(&mut |conn: Arc<Connection>| {
+                v.push(conn.info().as_str().into())
+            });
+            Ok(v.into())
+        },
         "setname" => {
             let name = unsafe { std::str::from_utf8_unchecked(&args[2]) }.to_string();
             conn.set_name(name);

+ 42 - 19
src/connection.rs

@@ -1,42 +1,56 @@
 use crate::db::Db;
 use std::collections::BTreeMap;
 use std::net::SocketAddr;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, RwLock};
 
 pub struct Connections {
-    connections: BTreeMap<u128, Arc<Mutex<Connection>>>,
-    counter: u128,
+    connections: RwLock<BTreeMap<u128, Arc<Connection>>>,
+    counter: RwLock<u128>,
 }
 
 impl Connections {
     pub fn new() -> Self {
         Self {
-            counter: 0,
-            connections: BTreeMap::new(),
+            counter: RwLock::new(0),
+            connections: RwLock::new(BTreeMap::new()),
         }
     }
 
-    pub fn new_connection(&mut self, db: Arc<Db>, addr: SocketAddr) -> Arc<Mutex<Connection>> {
-        let id = self.counter;
-        let conn = Arc::new(Mutex::new(Connection {
-            id,
+    pub fn remove(self: Arc<Connections>, conn: Arc<Connection>) {
+        let id = conn.id();
+        self.connections.write().unwrap().remove(&id);
+    }
+
+    pub fn new_connection(self: &Arc<Connections>, db: Arc<Db>, addr: SocketAddr) -> Arc<Connection> {
+        let mut id = self.counter.write().unwrap();
+
+        let conn = Arc::new(Connection {
+            id: *id,
             db,
             addr,
+            connections: self.clone(),
             current_db: 0,
-            name: None,
-        }));
-        self.counter += 1;
-        self.connections.insert(id, conn.clone());
+            name: RwLock::new(None),
+        });
+        self.connections.write().unwrap().insert(*id, conn.clone());
+        *id += 1;
         conn
     }
+
+    pub fn iter(&self, f: &mut dyn FnMut(Arc<Connection>)) {
+        for (_, value) in self.connections.read().unwrap().iter() {
+            f(value.clone())
+        }
+    }
 }
 
 pub struct Connection {
     id: u128,
     db: Arc<Db>,
     current_db: u32,
+    connections: Arc<Connections>,
     addr: SocketAddr,
-    name: Option<String>,
+    name: RwLock<Option<String>>,
 }
 
 impl Connection {
@@ -48,12 +62,21 @@ impl Connection {
         self.id
     }
 
-    pub fn name(&self) -> &Option<String> {
-        &self.name
+    pub fn destroy(self: Arc<Connection>) {
+        self.connections.clone().remove(self);
+    }
+
+    pub fn all_connections(&self) -> Arc<Connections> {
+        self.connections.clone()
+    }
+
+    pub fn name(&self) -> Option<String> {
+        self.name.read().unwrap().clone()
     }
 
-    pub fn set_name(&mut self, name: String) {
-        self.name = Some(name);
+    pub fn set_name(&self, name: String) {
+        let mut r = self.name.write().unwrap();
+        *r = Some(name);
     }
 
     #[allow(dead_code)]
@@ -64,7 +87,7 @@ impl Connection {
     pub fn info(&self) -> String {
         format!(
             "id={} addr={} name={:?} db={}\r\n",
-            self.id, self.addr, self.name, self.current_db
+            self.id, self.addr, self.name.read().unwrap(), self.current_db
         )
     }
 }

+ 2 - 2
src/macros.rs

@@ -29,7 +29,7 @@ macro_rules! dispatcher {
                 }
 
                 impl ExecutableCommand for Command {
-                    fn execute(&self, conn: &mut Connection, args: &[Bytes]) -> Result<Value, Error> {
+                    fn execute(&self, conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
                         $handler(conn, args)
                     }
 
@@ -55,7 +55,7 @@ macro_rules! dispatcher {
         use std::ops::Deref;
 
         pub trait ExecutableCommand {
-            fn execute(&self, conn: &mut Connection, args: &[Bytes]) -> Result<Value, Error>;
+            fn execute(&self, conn: &Connection, args: &[Bytes]) -> Result<Value, Error>;
 
             fn check_number_args(&self, n: usize) -> bool;
 

+ 7 - 3
src/server.rs

@@ -51,7 +51,7 @@ pub async fn serve(addr: String) -> Result<(), Box<dyn Error>> {
     info!("Listening on: {}", addr);
 
     let db = Arc::new(Db::new(1000));
-    let mut all_connections = Connections::new();
+    let all_connections = Arc::new(Connections::new());
 
     let db_for_purging = db.clone();
     tokio::spawn(async move {
@@ -69,7 +69,7 @@ pub async fn serve(addr: String) -> Result<(), Box<dyn Error>> {
                 tokio::spawn(async move {
                     let mut transport = Framed::new(socket, RedisParser);
 
-                    trace!("New connection {}", conn.lock().unwrap().id());
+                    trace!("New connection {}", conn.id());
 
                     while let Some(result) = transport.next().await {
                         match result {
@@ -77,7 +77,7 @@ pub async fn serve(addr: String) -> Result<(), Box<dyn Error>> {
                                 Ok(handler) => {
                                     let r = handler
                                         .deref()
-                                        .execute(&mut conn.lock().unwrap(), &args)
+                                        .execute(&conn, &args)
                                         .unwrap_or_else(|x| x.into());
                                     if transport.send(r).await.is_err() {
                                         break;
@@ -95,9 +95,13 @@ pub async fn serve(addr: String) -> Result<(), Box<dyn Error>> {
                             }
                         }
                     }
+
+                    conn.destroy();
                 });
             }
             Err(e) => println!("error accepting socket; error = {:?}", e),
         }
+
+
     }
 }