123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540 |
- use futures::executor::block_on;
- use mlua::{Compiler, Lua, Table, Value};
- use std::{
- collections::HashMap,
- hash::Hash,
- sync::{
- atomic::{AtomicU16, AtomicUsize, Ordering},
- Arc,
- },
- };
- use tokio::{
- sync::{mpsc, oneshot, Mutex, RwLock},
- time::{timeout, Duration},
- };
- #[async_trait::async_trait]
- pub trait VarStorage: Send + Sync {
- async fn get(&self, instance: usize, var: Variable) -> VarValue;
- async fn set(&self, instance: usize, var: Variable, value: VarValue);
- async fn shutdown(&self, instance: usize);
- }
- type Sender<I, R> = mpsc::Sender<(Vec<I>, oneshot::Sender<R>)>;
- type Receiver<I, R> = mpsc::Receiver<(Vec<I>, oneshot::Sender<R>)>;
- #[derive(Debug)]
- pub struct Program<X>
- where
- X: VarStorage + 'static,
- {
- opcodes: Vec<u8>,
- instances: Arc<AtomicUsize>,
- running: Arc<AtomicU16>,
- execution_id: Arc<AtomicUsize>,
- storage: Arc<X>,
- sender: Sender<VarValue, VarValue>,
- receiver: Arc<Mutex<Receiver<VarValue, VarValue>>>,
- }
/// A value exchanged between Lua programs and the storage backend.
#[derive(Debug, Clone, PartialEq)]
pub enum VarValue {
    /// The Lua value `nil`.
    Nil,
    /// The Lua value `true` or `false`.
    Boolean(bool),
    /// An integer number.
    Integer(i128),
    /// A floating point number.
    Number(f64),
    /// A string.
    String(String),
    /// An ordered list of values.
    Vector(Vec<VarValue>),
    /// A map from string keys to values.
    HashMap(HashMap<String, VarValue>),
    /// An error, carrying its message.
    ErrorType(String),
}
/// Name of a global variable handled by the storage backend.
///
/// A few well-known names have dedicated variants; every other name is kept
/// verbatim in [`Variable::Other`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Variable {
    /// The `balances` variable.
    Balances,
    /// The `accounts` variable.
    Accounts,
    /// The `transactions` variable.
    Transactions,
    /// The `payments` variable.
    Payments,
    /// Any other variable, identified by its name.
    Other(String),
}

impl From<String> for Variable {
    /// Maps a variable name to its variant. Matching is exact (case
    /// sensitive); unknown names become [`Variable::Other`].
    fn from(s: String) -> Self {
        match s.as_str() {
            "balances" => Self::Balances,
            "accounts" => Self::Accounts,
            "transactions" => Self::Transactions,
            "payments" => Self::Payments,
            _ => Self::Other(s),
        }
    }
}

impl Variable {
    /// Returns the name of the variable, the inverse of `From<String>`.
    pub fn name(&self) -> &str {
        match self {
            Self::Balances => "balances",
            Self::Accounts => "accounts",
            Self::Transactions => "transactions",
            Self::Payments => "payments",
            Self::Other(s) => s,
        }
    }
}
- impl<X> Program<X>
- where
- X: VarStorage + 'static,
- {
- pub fn new(opcodes: Vec<u8>, storage: Arc<X>) -> Program<X> {
- let (sender, receiver) = mpsc::channel(100);
- Self {
- storage,
- instances: Arc::new(AtomicUsize::new(0)),
- execution_id: Arc::new(AtomicUsize::new(0)),
- running: Arc::new(AtomicU16::new(0)),
- receiver: Arc::new(Mutex::new(receiver)),
- opcodes,
- sender,
- }
- }
- fn var_value_to_lua_val(lua: &Lua, value: VarValue) -> mlua::Result<Value> {
- match value {
- VarValue::Nil => Ok(Value::Nil),
- VarValue::Boolean(b) => Ok(Value::Boolean(b)),
- VarValue::Integer(i) => Ok(Value::Integer(i.try_into().unwrap())),
- VarValue::Number(n) => Ok(Value::Number(n)),
- VarValue::String(s) => Ok(Value::String(lua.create_string(&s)?)),
- VarValue::HashMap(map) => {
- let table = lua.create_table()?;
- for (k, v) in map {
- table.set(k, Self::var_value_to_lua_val(lua, v)?)?;
- }
- Ok(Value::Table(table))
- }
- VarValue::ErrorType(e) => Err(mlua::Error::RuntimeError(e.to_string())),
- _ => Err(mlua::Error::RuntimeError("Invalid type".into())),
- }
- }
- fn inject_dynamic_global_state(
- lua: &Lua,
- storage: Arc<X>,
- instance: usize,
- ) -> mlua::Result<Option<Table>> {
- lua.set_app_data(storage);
- let getter = lua.create_function(move |lua, (global, key): (Table, String)| {
- match global.raw_get::<_, Value>(key.clone())?.into() {
- Value::Nil => (),
- local_value => return Ok(local_value),
- };
- let storage = lua
- .app_data_ref::<Arc<X>>()
- .ok_or(mlua::Error::MismatchedRegistryKey)?
- .clone();
- let value = block_on(async move { storage.get(instance, key.into()).await });
- Self::var_value_to_lua_val(lua, value)
- })?;
- let setter =
- lua.create_function(move |lua, (global, key, value): (Table, String, Value)| {
- let storage = lua
- .app_data_ref::<Arc<X>>()
- .ok_or(mlua::Error::MismatchedRegistryKey)?
- .clone();
- let value: VarValue = if let Ok(value) = value.as_ref().try_into() {
- value
- } else {
- return global.raw_set(key, value);
- };
- block_on(async move {
- storage.set(instance, key.into(), value).await;
- Ok(())
- })
- })?;
- let metatable = lua.create_table()?;
- metatable.raw_set("__index", getter)?;
- metatable.raw_set("__newindex", setter)?;
- Ok(Some(metatable))
- }
- /// Returns a new Lua VM and a list of all the global variables to be
- /// persisted and read from the storage engine. Since lua is a dynamic
- /// language, other state variables may be read/updated dynamically, which
- /// is fine, this list is just for the initial state and any potential
- /// optimization.
- fn execute_program(state: Arc<X>, instance: usize, bytecode: &[u8]) -> mlua::Result<VarValue> {
- let lua = Lua::new();
- let globals = lua.globals();
- let require = lua.create_function(|_, (_,): (String,)| -> mlua::Result<()> {
- Err(mlua::Error::RuntimeError("require is not allowed".into()))
- })?;
- globals.set_metatable(Self::inject_dynamic_global_state(
- &lua,
- state.clone(),
- instance,
- )?);
- lua.set_memory_limit(100 * 1024 * 1024)?;
- // remove external require
- globals.set("require", require)?;
- drop(globals);
- // load main program
- let x: Value = lua.load(bytecode).call(())?;
- // shutdown the execution and let the storage / state engine know so all
- // locked variables by this execution_id can be released
- block_on(async move {
- state.shutdown(instance).await;
- });
- x.as_ref().try_into().map_err(|_| mlua::Error::StackError)
- }
- fn spawn(
- storage: Arc<X>,
- bytecode: Vec<u8>,
- instances: Arc<AtomicUsize>,
- exec_id: Arc<AtomicUsize>,
- running: Arc<AtomicU16>,
- receiver: Arc<Mutex<Receiver<VarValue, VarValue>>>,
- ) {
- if instances.load(Ordering::Relaxed) > 100 {
- return;
- }
- instances.fetch_add(1, Ordering::Relaxed);
- let max_timeout = Duration::from_secs(30);
- tokio::task::spawn_blocking(move || {
- loop {
- if let Ok(mut queue) =
- futures::executor::block_on(timeout(max_timeout, receiver.lock()))
- {
- if let Ok(Some((_inputs, output))) =
- futures::executor::block_on(timeout(max_timeout, queue.recv()))
- {
- let exec_id: usize = exec_id.fetch_add(1, Ordering::Relaxed);
- // drop queue lock to release the mutex so any other
- // free VM can use it to listen for incoming messages
- drop(queue);
- let ret =
- Self::execute_program(storage.clone(), exec_id, &bytecode).unwrap();
- running.fetch_add(1, Ordering::Relaxed);
- let _ = output.send(ret).unwrap();
- continue;
- }
- }
- break;
- }
- println!("Lua listener is exiting");
- instances.fetch_sub(1, Ordering::Relaxed);
- });
- }
- pub async fn exec(&self, input: Vec<VarValue>) -> VarValue {
- let (return_notifier, return_listener) = oneshot::channel();
- Self::spawn(
- self.storage.clone(),
- self.opcodes.clone(),
- self.instances.clone(),
- self.execution_id.clone(),
- self.running.clone(),
- self.receiver.clone(),
- );
- self.sender
- .send((input, return_notifier))
- .await
- .expect("valid");
- return_listener.await.expect("valid")
- }
- }
- impl TryFrom<&Value<'_>> for VarValue {
- type Error = String;
- fn try_from(value: &Value<'_>) -> Result<Self, Self::Error> {
- match value {
- Value::Nil => Ok(VarValue::Nil),
- Value::Boolean(b) => Ok(VarValue::Boolean(*b)),
- Value::Integer(i) => Ok(VarValue::Integer((*i).into())),
- Value::Number(n) => Ok(VarValue::Number(*n)),
- Value::String(s) => Ok(VarValue::String(s.to_str().unwrap().to_owned())),
- Value::Table(t) => {
- let mut map = HashMap::new();
- let mut iter = t.clone().pairs::<String, Value>().enumerate();
- let mut is_vector = true;
- while let Some((id, Ok((k, v)))) = iter.next() {
- if Ok(id + 1) != k.parse() {
- is_vector = false;
- }
- map.insert(k, v.as_ref().try_into()?);
- }
- Ok(if is_vector {
- let mut values = map
- .into_iter()
- .map(|(k, v)| k.parse().map(|k| (k, v)))
- .collect::<Result<Vec<(usize, VarValue)>, _>>()
- .unwrap();
- values.sort_by(|(a, _), (b, _)| a.cmp(b));
- VarValue::Vector(values.into_iter().map(|(_, v)| v).collect())
- } else {
- VarValue::HashMap(map)
- })
- }
- x => Err(format!("Invalid type: {:?}", x)),
- }
- }
- }
- #[derive(Debug)]
- pub struct VarStorageMem {
- storage: RwLock<HashMap<String, VarValue>>,
- locks: RwLock<HashMap<String, usize>>,
- var_locked_by_instance: Mutex<HashMap<usize, Vec<String>>>,
- }
- impl Default for VarStorageMem {
- fn default() -> Self {
- Self {
- storage: RwLock::new(HashMap::new()),
- locks: RwLock::new(HashMap::new()),
- var_locked_by_instance: Mutex::new(HashMap::new()),
- }
- }
- }
- impl VarStorageMem {
- async fn lock(&self, instance: usize, var: &Variable) {
- let locks = self.locks.read().await;
- let name = var.name();
- if locks.get(name).map(|v| *v) == Some(instance) {
- // The variable is already locked by this instance
- return;
- }
- drop(locks);
- loop {
- // wait here while the locked is not null or it is locked by another
- // instance
- let locks = self.locks.read().await;
- let var_lock = locks.get(name).map(|v| *v);
- if var_lock.is_none() || var_lock == Some(instance) {
- break;
- }
- drop(locks);
- tokio::time::sleep(Duration::from_micros(10)).await;
- }
- loop {
- let mut locks = self.locks.write().await;
- let var_lock = locks.get(name).map(|v| *v);
- if !var_lock.is_none() {
- if var_lock == Some(instance) {
- break;
- }
- drop(locks);
- tokio::time::sleep(Duration::from_micros(10)).await;
- continue;
- }
- locks.insert(name.to_owned(), instance);
- let mut vars_by_instance = self.var_locked_by_instance.lock().await;
- if let Some(vars) = vars_by_instance.get_mut(&instance) {
- vars.push(name.to_owned());
- } else {
- vars_by_instance.insert(instance, vec![name.to_owned()]);
- }
- }
- }
- }
- #[async_trait::async_trait]
- impl VarStorage for VarStorageMem {
- async fn get(&self, instance: usize, var: Variable) -> VarValue {
- self.lock(instance, &var).await;
- self.storage
- .read()
- .await
- .get(var.name())
- .cloned()
- .unwrap_or(VarValue::Nil)
- }
- async fn set(&self, instance: usize, var: Variable, value: VarValue) {
- self.lock(instance, &var).await;
- self.storage
- .write()
- .await
- .insert(var.name().to_owned(), value);
- }
- async fn shutdown(&self, instance: usize) {
- let mut vars_by_instance = self.var_locked_by_instance.lock().await;
- let mut locks = self.locks.write().await;
- if let Some(vars) = vars_by_instance.remove(&instance) {
- for var in vars {
- if locks.get(&var).map(|v| *v) == Some(instance) {
- locks.remove(&var);
- }
- }
- }
- }
- }
- pub struct Runtime<K: Hash + Eq, X: VarStorage + 'static> {
- vms: RwLock<HashMap<K, Program<X>>>,
- }
- impl<K: Hash + Eq, X: VarStorage + 'static> Runtime<K, X> {
- pub fn new() -> Arc<Self> {
- Arc::new(Self {
- vms: RwLock::new(HashMap::new()),
- })
- }
- pub async fn register_program(&self, name: K, program: &str, storage: Arc<X>) -> bool {
- self.register_opcodes(name, Compiler::new().compile(program), storage)
- .await
- }
- pub async fn exec(&self, id: &K) -> Option<VarValue> {
- if let Some(vm) = self.vms.read().await.get(id) {
- Some(vm.exec(vec![VarValue::Integer(1)]).await)
- } else {
- None
- }
- }
- pub async fn register_opcodes(&self, name: K, opcodes: Vec<u8>, storage: Arc<X>) -> bool {
- let mut vms = self.vms.write().await;
- vms.insert(name, Program::new(opcodes, storage)).is_some()
- }
- pub async fn shutdown(&self) {
- let mut vms = self.vms.write().await;
- vms.clear();
- }
- }
- #[tokio::main]
- async fn main() {
- use std::{sync::Arc, time::Instant};
- let mem = Arc::new(VarStorageMem::default());
- async fn do_loop(vms: Arc<Runtime<String, VarStorageMem>>) {
- // Create N threads to execute the Lua code in parallel
- let num_threads = 400;
- let (tx, mut rx) = mpsc::channel(num_threads);
- for _ in 0..num_threads {
- let vm = vms.clone();
- let tx_clone = tx.clone();
- tokio::spawn(async move {
- let start_time = Instant::now();
- let result = vm.exec(&"foo".to_owned()).await;
- // Send the result back to the main thread
- let _ = tx_clone.send(result).await;
- let elapsed_time = Instant::now() - start_time;
- // Print the elapsed time in seconds and milliseconds
- println!(
- "Elapsed time: {} seconds {} milliseconds",
- elapsed_time.as_secs(),
- elapsed_time.as_millis(),
- );
- });
- }
- drop(tx);
- loop {
- let result = rx.recv().await;
- if result.is_none() {
- break;
- }
- println!("Result: {:?}", result.unwrap());
- }
- }
- let vms = Runtime::new();
- // Compile Lua code
- let _code = r#"
- function add(a, b)
- calls = calls + 1
- print("Call from old " .. pid .. " " .. calls)
- return a + b
- end
- print("hello world " .. pid)
- "#;
- let code = r#"
- if pid == nil then
- pid = 0
- end
- pid = pid + 1
- print("hello world " .. pid)
- return true
- "#;
- let _ = vms
- .register_program("foo".to_owned(), code, mem.clone())
- .await;
- do_loop(vms.clone()).await;
- let code = r#"
- if pid == nil then
- pid = 0
- end
- pid = pid + 1
- function add(a, b)
- foo = {1,"foo"}
- print("Call from new " .. pid .. " ")
- return a + b
- end
- print("hello world " .. pid .. " = " .. add(pid, pid))
- return false
- "#;
- let y = vms
- .register_program("foo".to_owned(), code, mem.clone())
- .await;
- tokio::time::sleep(Duration::from_secs(3)).await;
- do_loop(vms.clone()).await;
- vms.shutdown().await;
- tokio::time::sleep(Duration::from_secs(1)).await;
- println!("{} {:?}", "foo", mem);
- }
|