Jelajahi Sumber

Add variable sharing

Cesar Rodas 1 tahun lalu
induk
melakukan
946e6c7ddf
1 mengubah file dengan 243 tambahan dan 127 penghapusan
  1. 243 127
      utxo/src/main.rs

+ 243 - 127
utxo/src/main.rs

@@ -1,10 +1,10 @@
-use mlua::{Compiler, FromLua, Function, IntoLuaMulti, Lua, Table, Value};
-use sha2::{Digest, Sha256};
+use futures::executor::block_on;
+use mlua::{Compiler, Lua, Table, Value};
 use std::{
     collections::HashMap,
     hash::Hash,
     sync::{
-        atomic::{AtomicU16, Ordering},
+        atomic::{AtomicU16, AtomicUsize, Ordering},
         Arc,
     },
 };
@@ -13,37 +13,34 @@ use tokio::{
     time::{timeout, Duration},
 };
 
-pub type ProgramId = [u8; 32];
-
 #[async_trait::async_trait]
 pub trait VarStorage: Send + Sync {
-    type Error: ToString;
+    async fn get(&self, instance: usize, var: Variable) -> VarValue;
 
-    async fn get(&self, var: Variable) -> &VarValue<Self::Error>;
+    async fn set(&self, instance: usize, var: Variable, value: VarValue);
 
-    async fn set(&mut self, var: Variable, value: VarValue<Self::Error>);
+    async fn shutdown(&self, instance: usize);
 }
 
-type Sender<I, R> = mpsc::Sender<(I, oneshot::Sender<R>)>;
-type Receiver<I, R> = mpsc::Receiver<(I, oneshot::Sender<R>)>;
+type Sender<I, R> = mpsc::Sender<(Vec<I>, oneshot::Sender<R>)>;
+type Receiver<I, R> = mpsc::Receiver<(Vec<I>, oneshot::Sender<R>)>;
 
 #[derive(Debug)]
-pub struct Program<X, I, R>
+pub struct Program<X>
 where
     X: VarStorage + 'static,
-    for<'lua> I: IntoLuaMulti<'lua> + Sync + Send + 'lua,
-    for<'lua> R: FromLua<'lua> + Clone + Sync + Send + 'lua,
 {
     opcodes: Vec<u8>,
-    instances: Arc<AtomicU16>,
+    instances: Arc<AtomicUsize>,
     running: Arc<AtomicU16>,
+    execution_id: Arc<AtomicUsize>,
     storage: Arc<X>,
-    sender: Sender<I, R>,
-    receiver: Arc<Mutex<Receiver<I, R>>>,
+    sender: Sender<VarValue, VarValue>,
+    receiver: Arc<Mutex<Receiver<VarValue, VarValue>>>,
 }
 
 #[derive(Debug, Clone)]
-pub enum VarValue<E: ToString> {
+pub enum VarValue {
     /// The Lua value `nil`.
     Nil,
     /// The Lua value `true` or `false`.
@@ -55,11 +52,11 @@ pub enum VarValue<E: ToString> {
     /// String
     String(String),
     /// A vector
-    Vector(Vec<VarValue<E>>),
+    Vector(Vec<VarValue>),
     /// A
-    HashMap(HashMap<String, VarValue<E>>),
+    HashMap(HashMap<String, VarValue>),
     /// An error
-    Error(E),
+    ErrorType(String),
 }
 
 pub enum Variable {
@@ -94,17 +91,16 @@ impl Variable {
     }
 }
 
-impl<X, I, R> Program<X, I, R>
+impl<X> Program<X>
 where
     X: VarStorage + 'static,
-    for<'lua> I: IntoLuaMulti<'lua> + Sync + Send + 'lua,
-    for<'lua> R: FromLua<'lua> + Clone + Sync + Send + 'lua,
 {
-    pub fn new(opcodes: Vec<u8>, storage: Arc<X>) -> Program<X, I, R> {
+    pub fn new(opcodes: Vec<u8>, storage: Arc<X>) -> Program<X> {
         let (sender, receiver) = mpsc::channel(100);
         Self {
             storage,
-            instances: Arc::new(AtomicU16::new(0)),
+            instances: Arc::new(AtomicUsize::new(0)),
+            execution_id: Arc::new(AtomicUsize::new(0)),
             running: Arc::new(AtomicU16::new(0)),
             receiver: Arc::new(Mutex::new(receiver)),
             opcodes,
@@ -112,27 +108,64 @@ where
         }
     }
 
-    fn dynamic_global_state(storage: Arc<X>, lua: &Lua) -> Result<Option<Table>, mlua::Error> {
-        lua.set_app_data(storage);
-
-        let getter = lua.create_async_function(|lua, (_, key): (Table, String)| async move {
-            let storage = lua.app_data_ref::<Arc<X>>().unwrap();
-            let x = storage.get(key.into()).await;
+    fn var_value_to_lua_val(lua: &Lua, value: VarValue) -> mlua::Result<Value> {
+        match value {
+            VarValue::Nil => Ok(Value::Nil),
+            VarValue::Boolean(b) => Ok(Value::Boolean(b)),
+            VarValue::Integer(i) => Ok(Value::Integer(i.try_into().unwrap())),
+            VarValue::Number(n) => Ok(Value::Number(n)),
+            VarValue::String(s) => Ok(Value::String(lua.create_string(&s)?)),
+            VarValue::HashMap(map) => {
+                let table = lua.create_table()?;
+                for (k, v) in map {
+                    table.set(k, Self::var_value_to_lua_val(lua, v)?)?;
+                }
+                Ok(Value::Table(table))
+            }
+            VarValue::ErrorType(e) => Err(mlua::Error::RuntimeError(e.to_string())),
+            _ => Err(mlua::Error::RuntimeError("Invalid type".into())),
+        }
+    }
 
-            let x = Ok(format!("foo bar -> {}", key));
+    fn inject_dynamic_global_state(
+        lua: &Lua,
+        storage: Arc<X>,
+        instance: usize,
+    ) -> mlua::Result<Option<Table>> {
+        lua.set_app_data(storage);
 
-            drop(storage);
-            x
+        let getter = lua.create_function(move |lua, (global, key): (Table, String)| {
+            match global.raw_get::<_, Value>(key.clone())?.into() {
+                Value::Nil => (),
+                local_value => return Ok(local_value),
+            };
+            let storage = lua
+                .app_data_ref::<Arc<X>>()
+                .ok_or(mlua::Error::MismatchedRegistryKey)?
+                .clone();
+            let value = block_on(async move { storage.get(instance, key.into()).await });
+            Self::var_value_to_lua_val(lua, value)
         })?;
         let setter =
-            lua.create_async_function(|_, (_, key, value): (Table, String, Value)| async move {
-                panic!("set {} -> {:?}", key, value);
-                Ok(())
+            lua.create_function(move |lua, (global, key, value): (Table, String, Value)| {
+                let storage = lua
+                    .app_data_ref::<Arc<X>>()
+                    .ok_or(mlua::Error::MismatchedRegistryKey)?
+                    .clone();
+                let value: VarValue = if let Ok(value) = value.as_ref().try_into() {
+                    value
+                } else {
+                    return global.raw_set(key, value);
+                };
+                block_on(async move {
+                    storage.set(instance, key.into(), value).await;
+                    Ok(())
+                })
             })?;
 
-        let metatable = lua.create_table().unwrap();
-        metatable.raw_set("__index", getter).unwrap();
-        metatable.raw_set("__newindex", setter).unwrap();
+        let metatable = lua.create_table()?;
+        metatable.raw_set("__index", getter)?;
+        metatable.raw_set("__newindex", setter)?;
 
         Ok(Some(metatable))
     }
@@ -142,41 +175,48 @@ where
     /// language, other state variables may be read/updated dynamically, which
     /// is fine, this list is just for the initial state and any potential
     /// optimization.
-    fn get_lua_vm(storage: Arc<X>, bytecode: &[u8]) -> (Vec<Variable>, Lua) {
+    fn execute_program(state: Arc<X>, instance: usize, bytecode: &[u8]) -> mlua::Result<VarValue> {
         let lua = Lua::new();
         let globals = lua.globals();
 
-        let require = lua
-            .create_function(|_, (_,): (String,)| -> mlua::Result<()> {
-                Err(mlua::Error::RuntimeError("require is not allowed".into()))
-            })
-            .unwrap();
+        let require = lua.create_function(|_, (_,): (String,)| -> mlua::Result<()> {
+            Err(mlua::Error::RuntimeError("require is not allowed".into()))
+        })?;
 
-        globals.set_metatable(Self::dynamic_global_state(storage, &lua).unwrap());
-        lua.set_memory_limit(100 * 1024 * 1024).unwrap();
+        globals.set_metatable(Self::inject_dynamic_global_state(
+            &lua,
+            state.clone(),
+            instance,
+        )?);
+        lua.set_memory_limit(100 * 1024 * 1024)?;
 
         // remove external require
-        globals.set("require", require).unwrap();
+        globals.set("require", require)?;
         drop(globals);
 
         // load main program
-        lua.load(bytecode).exec().unwrap();
-        panic!("cesdar");
+        let x: Value = lua.load(bytecode).call(())?;
+
+        // shutdown the execution and let the storage / state engine know so all
+        // locked variables by this execution_id can be released
+        block_on(async move {
+            state.shutdown(instance).await;
+        });
 
-        (vec![], lua)
+        x.as_ref().try_into().map_err(|_| mlua::Error::StackError)
     }
 
     fn spawn(
         storage: Arc<X>,
         bytecode: Vec<u8>,
-        instances: Arc<AtomicU16>,
+        instances: Arc<AtomicUsize>,
+        exec_id: Arc<AtomicUsize>,
         running: Arc<AtomicU16>,
-        receiver: Arc<Mutex<Receiver<I, R>>>,
+        receiver: Arc<Mutex<Receiver<VarValue, VarValue>>>,
     ) {
         if instances.load(Ordering::Relaxed) > 10 {
             return;
         }
-        let (_, lua) = Self::get_lua_vm(storage.clone(), &bytecode);
 
         instances.fetch_add(1, Ordering::Relaxed);
         let max_timeout = Duration::from_secs(30);
@@ -186,36 +226,37 @@ where
                 if let Ok(mut queue) =
                     futures::executor::block_on(timeout(max_timeout, receiver.lock()))
                 {
-                    if let Ok(Some((inputs, output))) =
+                    if let Ok(Some((_inputs, output))) =
                         futures::executor::block_on(timeout(max_timeout, queue.recv()))
                     {
+                        let exec_id: usize = exec_id.fetch_add(1, Ordering::Relaxed);
                         // drop queue lock to release the mutex so any other
                         // free VM can use it to listen for incoming messages
                         drop(queue);
 
-                        let (_, lua) = Self::get_lua_vm(storage.clone(), &bytecode);
-                        let f: Function = lua.globals().get("add").unwrap();
+                        let ret =
+                            Self::execute_program(storage.clone(), exec_id, &bytecode).unwrap();
 
                         running.fetch_add(1, Ordering::Relaxed);
-                        let ret = f.call::<I, Value>(inputs).unwrap();
-                        let _ = output.send(R::from_lua(ret, &lua).unwrap());
+                        let _ = output.send(ret).unwrap();
                         continue;
                     }
                 }
                 break;
             }
 
-            println!("Lua is exiting");
+            println!("Lua listener is exiting");
             instances.fetch_sub(1, Ordering::Relaxed);
         });
     }
 
-    pub async fn exec(&self, input: I) -> R {
+    pub async fn exec(&self, input: Vec<VarValue>) -> VarValue {
         let (return_notifier, return_listener) = oneshot::channel();
         Self::spawn(
             self.storage.clone(),
             self.opcodes.clone(),
             self.instances.clone(),
+            self.execution_id.clone(),
             self.running.clone(),
             self.receiver.clone(),
         );
@@ -227,59 +268,148 @@ where
     }
 }
 
+impl TryFrom<&Value<'_>> for VarValue {
+    type Error = String;
+
+    fn try_from(value: &Value<'_>) -> Result<Self, Self::Error> {
+        match value {
+            Value::Nil => Ok(VarValue::Nil),
+            Value::Boolean(b) => Ok(VarValue::Boolean(*b)),
+            Value::Integer(i) => Ok(VarValue::Integer((*i).into())),
+            Value::Number(n) => Ok(VarValue::Number(*n)),
+            Value::String(s) => Ok(VarValue::String(s.to_str().unwrap().to_owned())),
+            Value::Table(t) => {
+                let mut map = HashMap::new();
+                let mut iter = t.clone().pairs::<String, Value>().enumerate();
+                let mut is_vector = true;
+                while let Some((id, Ok((k, v)))) = iter.next() {
+                    if Ok(id + 1) != k.parse() {
+                        is_vector = false;
+                    }
+                    map.insert(k, v.as_ref().try_into()?);
+                }
+
+                Ok(if is_vector {
+                    let mut values = map
+                        .into_iter()
+                        .map(|(k, v)| k.parse().map(|k| (k, v)))
+                        .collect::<Result<Vec<(usize, VarValue)>, _>>()
+                        .unwrap();
+
+                    values.sort_by(|(a, _), (b, _)| a.cmp(b));
+
+                    VarValue::Vector(values.into_iter().map(|(_, v)| v).collect())
+                } else {
+                    VarValue::HashMap(map)
+                })
+            }
+            x => Err(format!("Invalid type: {:?}", x)),
+        }
+    }
+}
+
+#[derive(Debug)]
 pub struct VarStorageMem {
-    storage: HashMap<String, VarValue<mlua::Error>>,
+    storage: RwLock<HashMap<String, VarValue>>,
+    locks: RwLock<HashMap<String, usize>>,
+    var_locked_by_instance: Mutex<HashMap<usize, Vec<String>>>,
 }
 
 impl Default for VarStorageMem {
     fn default() -> Self {
         Self {
-            storage: HashMap::new(),
+            storage: RwLock::new(HashMap::new()),
+            locks: RwLock::new(HashMap::new()),
+            var_locked_by_instance: Mutex::new(HashMap::new()),
         }
     }
 }
 
-impl TryInto<VarValue<mlua::Error>> for Value<'_> {
-    type Error = String;
+impl VarStorageMem {
+    async fn lock(&self, instance: usize, var: &Variable) {
+        let locks = self.locks.read().await;
+        let name = var.name();
 
-    fn try_into(
-        self,
-    ) -> Result<VarValue<mlua::Error>, <Value<'static> as TryInto<VarValue<mlua::Error>>>::Error>
-    {
-        match self {
-            Value::Nil => Ok(VarValue::Nil),
-            Value::Boolean(b) => Ok(VarValue::Boolean(b)),
-            Value::Integer(i) => Ok(VarValue::Integer(i.into())),
-            Value::Number(n) => Ok(VarValue::Number(n)),
-            Value::String(s) => Ok(VarValue::String(s.to_str().unwrap().to_owned())),
-            Value::Table(t) => {
-                let mut map = HashMap::new();
-                let mut iter = t.pairs::<String, Value>();
-                while let Some(Ok((k, v))) = iter.next() {
-                    map.insert(k, v.try_into()?);
+        if locks.get(name).map(|v| *v) == Some(instance) {
+            // The variable is already locked by this instance
+            return;
+        }
+
+        drop(locks);
+
+        loop {
+            // wait here while the locked is not null or it is locked by another
+            // instance
+            let locks = self.locks.read().await;
+            let var_lock = locks.get(name).map(|v| *v);
+            if var_lock.is_none() || var_lock == Some(instance) {
+                break;
+            }
+            drop(locks);
+            tokio::time::sleep(Duration::from_micros(10)).await;
+        }
+
+        loop {
+            let mut locks = self.locks.write().await;
+            let var_lock = locks.get(name).map(|v| *v);
+
+            if !var_lock.is_none() {
+                if var_lock == Some(instance) {
+                    break;
                 }
-                Ok(VarValue::HashMap(map))
+                drop(locks);
+                tokio::time::sleep(Duration::from_micros(10)).await;
+                continue;
+            }
+
+            locks.insert(name.to_owned(), instance);
+
+            let mut vars_by_instance = self.var_locked_by_instance.lock().await;
+            if let Some(vars) = vars_by_instance.get_mut(&instance) {
+                vars.push(name.to_owned());
+            } else {
+                vars_by_instance.insert(instance, vec![name.to_owned()]);
             }
-            _ => Err("Invalid type".into()),
         }
     }
 }
 
 #[async_trait::async_trait]
 impl VarStorage for VarStorageMem {
-    type Error = mlua::Error;
+    async fn get(&self, instance: usize, var: Variable) -> VarValue {
+        self.lock(instance, &var).await;
+        self.storage
+            .read()
+            .await
+            .get(var.name())
+            .cloned()
+            .unwrap_or(VarValue::Nil)
+    }
 
-    async fn get(&self, var: Variable) -> &VarValue<Self::Error> {
-        self.storage.get(var.name()).unwrap_or(&VarValue::Nil)
+    async fn set(&self, instance: usize, var: Variable, value: VarValue) {
+        self.lock(instance, &var).await;
+        self.storage
+            .write()
+            .await
+            .insert(var.name().to_owned(), value);
     }
 
-    async fn set(&mut self, var: Variable, value: VarValue<Self::Error>) {
-        self.storage.insert(var.name().to_owned(), value);
+    async fn shutdown(&self, instance: usize) {
+        let mut vars_by_instance = self.var_locked_by_instance.lock().await;
+        let mut locks = self.locks.write().await;
+
+        if let Some(vars) = vars_by_instance.remove(&instance) {
+            for var in vars {
+                if locks.get(&var).map(|v| *v) == Some(instance) {
+                    locks.remove(&var);
+                }
+            }
+        }
     }
 }
 
 pub struct Runtime<K: Hash + Eq, X: VarStorage + 'static> {
-    vms: RwLock<HashMap<K, Program<X, (i64, i64), i64>>>,
+    vms: RwLock<HashMap<K, Program<X>>>,
 }
 
 impl<K: Hash + Eq, X: VarStorage + 'static> Runtime<K, X> {
@@ -289,38 +419,23 @@ impl<K: Hash + Eq, X: VarStorage + 'static> Runtime<K, X> {
         })
     }
 
-    pub async fn register_program(
-        &self,
-        name: K,
-        program: &str,
-        storage: Arc<X>,
-    ) -> (ProgramId, bool) {
+    pub async fn register_program(&self, name: K, program: &str, storage: Arc<X>) -> bool {
         self.register_opcodes(name, Compiler::new().compile(program), storage)
             .await
     }
 
-    pub async fn exec(&self, id: &K) -> Option<i64> {
+    pub async fn exec(&self, id: &K) -> Option<VarValue> {
         if let Some(vm) = self.vms.read().await.get(id) {
-            Some(vm.exec((22, 33)).await)
+            Some(vm.exec(vec![VarValue::Integer(1)]).await)
         } else {
             None
         }
     }
 
-    pub async fn register_opcodes(
-        &self,
-        name: K,
-        opcodes: Vec<u8>,
-        storage: Arc<X>,
-    ) -> (ProgramId, bool) {
+    pub async fn register_opcodes(&self, name: K, opcodes: Vec<u8>, storage: Arc<X>) -> bool {
         let mut vms = self.vms.write().await;
 
-        let mut hasher = Sha256::new();
-        hasher.update(&opcodes);
-        (
-            hasher.finalize().into(),
-            vms.insert(name, Program::new(opcodes, storage)).is_some(),
-        )
+        vms.insert(name, Program::new(opcodes, storage)).is_some()
     }
 
     pub async fn shutdown(&self) {
@@ -335,7 +450,7 @@ async fn main() {
 
     let mem = Arc::new(VarStorageMem::default());
 
-    async fn do_loop(mem: Arc<VarStorageMem>, vms: Arc<Runtime<String, VarStorageMem>>) {
+    async fn do_loop(vms: Arc<Runtime<String, VarStorageMem>>) {
         // Create N threads to execute the Lua code in parallel
         let num_threads = 400;
         let (tx, mut rx) = mpsc::channel(num_threads);
@@ -343,8 +458,6 @@ async fn main() {
             let vm = vms.clone();
             let tx_clone = tx.clone();
 
-            let result = vm.exec(&"foo".to_owned()).await;
-
             tokio::spawn(async move {
                 let start_time = Instant::now();
                 let result = vm.exec(&"foo".to_owned()).await;
@@ -370,7 +483,7 @@ async fn main() {
             if result.is_none() {
                 break;
             }
-            println!("Result: {:?}", result,);
+            println!("Result: {:?}", result.unwrap());
         }
     }
 
@@ -386,39 +499,42 @@ async fn main() {
         print("hello world " .. pid)
     "#;
     let code = r#"
-        --require("foo")
-        local foobar = "foo"
-        foo.foobar = "yy"
-        print("hello world " .. pid)
-        pid = 999
-        foo = "cesar"
+        if pid == nil then
+            pid = 0
+        end
+        pid = pid + 1
         print("hello world " .. pid)
+        return true
     "#;
 
     let _ = vms
         .register_program("foo".to_owned(), code, mem.clone())
         .await;
-    do_loop(mem.clone(), vms.clone()).await;
+    do_loop(vms.clone()).await;
 
     let code = r#"
-        calls = 0
-        pid = 1
+        if pid == nil then
+            pid = 0
+        end
+        pid = pid + 1
         function add(a, b)
-            calls = calls + 1
-            print("Call from new " .. pid .. " "  .. calls)
+            foo = {1,"foo"}
+            print("Call from new " .. pid .. " ")
             return a + b
         end
-        print("hello world " .. pid)
+        print("hello world " .. pid .. " = " .. add(pid, pid))
+        return false
     "#;
     let y = vms
         .register_program("foo".to_owned(), code, mem.clone())
         .await;
-    println!("{} {:?}", "foo", y);
     tokio::time::sleep(Duration::from_secs(3)).await;
 
-    do_loop(mem.clone(), vms.clone()).await;
+    do_loop(vms.clone()).await;
 
     vms.shutdown().await;
 
     tokio::time::sleep(Duration::from_secs(1)).await;
+
+    println!("{} {:?}", "foo", mem);
 }