main.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. use futures::executor::block_on;
  2. use mlua::{Compiler, Lua, Table, Value};
  3. use std::{
  4. collections::HashMap,
  5. hash::Hash,
  6. sync::{
  7. atomic::{AtomicU16, AtomicUsize, Ordering},
  8. Arc,
  9. },
  10. };
  11. use tokio::{
  12. sync::{mpsc, oneshot, Mutex, RwLock},
  13. time::{timeout, Duration},
  14. };
  15. #[async_trait::async_trait]
  16. pub trait VarStorage: Send + Sync {
  17. async fn get(&self, instance: usize, var: Variable) -> VarValue;
  18. async fn set(&self, instance: usize, var: Variable, value: VarValue);
  19. async fn shutdown(&self, instance: usize);
  20. }
  21. type Sender<I, R> = mpsc::Sender<(Vec<I>, oneshot::Sender<R>)>;
  22. type Receiver<I, R> = mpsc::Receiver<(Vec<I>, oneshot::Sender<R>)>;
  23. #[derive(Debug)]
  24. pub struct Program<X>
  25. where
  26. X: VarStorage + 'static,
  27. {
  28. opcodes: Vec<u8>,
  29. instances: Arc<AtomicUsize>,
  30. running: Arc<AtomicU16>,
  31. execution_id: Arc<AtomicUsize>,
  32. storage: Arc<X>,
  33. sender: Sender<VarValue, VarValue>,
  34. receiver: Arc<Mutex<Receiver<VarValue, VarValue>>>,
  35. }
  /// A value exchanged between a Lua VM and a [`VarStorage`] backend.
  #[derive(Debug, Clone)]
  pub enum VarValue {
      /// Lua `nil`, also used for variables that were never stored.
      Nil,
      /// Lua `true` / `false`.
      Boolean(bool),
      /// Integer, wide enough to hold any Lua integer.
      Integer(i128),
      /// Floating point number.
      Number(f64),
      /// UTF-8 string.
      String(String),
      /// Ordered sequence (a Lua table whose keys are exactly 1..=n).
      Vector(Vec<VarValue>),
      /// String-keyed map (any other Lua table).
      HashMap(HashMap<String, VarValue>),
      /// An error message.
      ErrorType(String),
  }
  /// Name of a variable kept in a [`VarStorage`] backend.
  ///
  /// A few well-known names get dedicated variants; every other global name
  /// is carried in [`Variable::Other`].
  #[derive(Debug, Clone, PartialEq, Eq, Hash)]
  pub enum Variable {
      Balances,
      Accounts,
      Transactions,
      Payments,
      Other(String),
  }

  impl From<String> for Variable {
      /// Maps a global's name to its variant; unknown names become `Other(name)`.
      fn from(s: String) -> Self {
          match s.as_str() {
              "balances" => Self::Balances,
              "accounts" => Self::Accounts,
              "transactions" => Self::Transactions,
              "payments" => Self::Payments,
              _ => Self::Other(s),
          }
      }
  }

  impl Variable {
      /// Storage key of this variable; `Variable::from(s).name() == s` for every `s`.
      pub fn name(&self) -> &str {
          match self {
              Self::Balances => "balances",
              Self::Accounts => "accounts",
              Self::Transactions => "transactions",
              Self::Payments => "payments",
              Self::Other(s) => s,
          }
      }
  }
  84. impl<X> Program<X>
  85. where
  86. X: VarStorage + 'static,
  87. {
  88. pub fn new(opcodes: Vec<u8>, storage: Arc<X>) -> Program<X> {
  89. let (sender, receiver) = mpsc::channel(100);
  90. Self {
  91. storage,
  92. instances: Arc::new(AtomicUsize::new(0)),
  93. execution_id: Arc::new(AtomicUsize::new(0)),
  94. running: Arc::new(AtomicU16::new(0)),
  95. receiver: Arc::new(Mutex::new(receiver)),
  96. opcodes,
  97. sender,
  98. }
  99. }
  100. fn var_value_to_lua_val(lua: &Lua, value: VarValue) -> mlua::Result<Value> {
  101. match value {
  102. VarValue::Nil => Ok(Value::Nil),
  103. VarValue::Boolean(b) => Ok(Value::Boolean(b)),
  104. VarValue::Integer(i) => Ok(Value::Integer(i.try_into().unwrap())),
  105. VarValue::Number(n) => Ok(Value::Number(n)),
  106. VarValue::String(s) => Ok(Value::String(lua.create_string(&s)?)),
  107. VarValue::HashMap(map) => {
  108. let table = lua.create_table()?;
  109. for (k, v) in map {
  110. table.set(k, Self::var_value_to_lua_val(lua, v)?)?;
  111. }
  112. Ok(Value::Table(table))
  113. }
  114. VarValue::ErrorType(e) => Err(mlua::Error::RuntimeError(e.to_string())),
  115. _ => Err(mlua::Error::RuntimeError("Invalid type".into())),
  116. }
  117. }
  118. fn inject_dynamic_global_state(
  119. lua: &Lua,
  120. storage: Arc<X>,
  121. instance: usize,
  122. ) -> mlua::Result<Option<Table>> {
  123. lua.set_app_data(storage);
  124. let getter = lua.create_function(move |lua, (global, key): (Table, String)| {
  125. match global.raw_get::<_, Value>(key.clone())?.into() {
  126. Value::Nil => (),
  127. local_value => return Ok(local_value),
  128. };
  129. let storage = lua
  130. .app_data_ref::<Arc<X>>()
  131. .ok_or(mlua::Error::MismatchedRegistryKey)?
  132. .clone();
  133. let value = block_on(async move { storage.get(instance, key.into()).await });
  134. Self::var_value_to_lua_val(lua, value)
  135. })?;
  136. let setter =
  137. lua.create_function(move |lua, (global, key, value): (Table, String, Value)| {
  138. let storage = lua
  139. .app_data_ref::<Arc<X>>()
  140. .ok_or(mlua::Error::MismatchedRegistryKey)?
  141. .clone();
  142. let value: VarValue = if let Ok(value) = value.as_ref().try_into() {
  143. value
  144. } else {
  145. return global.raw_set(key, value);
  146. };
  147. block_on(async move {
  148. storage.set(instance, key.into(), value).await;
  149. Ok(())
  150. })
  151. })?;
  152. let metatable = lua.create_table()?;
  153. metatable.raw_set("__index", getter)?;
  154. metatable.raw_set("__newindex", setter)?;
  155. Ok(Some(metatable))
  156. }
  157. /// Returns a new Lua VM and a list of all the global variables to be
  158. /// persisted and read from the storage engine. Since lua is a dynamic
  159. /// language, other state variables may be read/updated dynamically, which
  160. /// is fine, this list is just for the initial state and any potential
  161. /// optimization.
  162. fn execute_program(state: Arc<X>, instance: usize, bytecode: &[u8]) -> mlua::Result<VarValue> {
  163. let lua = Lua::new();
  164. let globals = lua.globals();
  165. let require = lua.create_function(|_, (_,): (String,)| -> mlua::Result<()> {
  166. Err(mlua::Error::RuntimeError("require is not allowed".into()))
  167. })?;
  168. globals.set_metatable(Self::inject_dynamic_global_state(
  169. &lua,
  170. state.clone(),
  171. instance,
  172. )?);
  173. lua.set_memory_limit(100 * 1024 * 1024)?;
  174. // remove external require
  175. globals.set("require", require)?;
  176. drop(globals);
  177. // load main program
  178. let x: Value = lua.load(bytecode).call(())?;
  179. // shutdown the execution and let the storage / state engine know so all
  180. // locked variables by this execution_id can be released
  181. block_on(async move {
  182. state.shutdown(instance).await;
  183. });
  184. x.as_ref().try_into().map_err(|_| mlua::Error::StackError)
  185. }
  186. fn spawn(
  187. storage: Arc<X>,
  188. bytecode: Vec<u8>,
  189. instances: Arc<AtomicUsize>,
  190. exec_id: Arc<AtomicUsize>,
  191. running: Arc<AtomicU16>,
  192. receiver: Arc<Mutex<Receiver<VarValue, VarValue>>>,
  193. ) {
  194. if instances.load(Ordering::Relaxed) > 100 {
  195. return;
  196. }
  197. instances.fetch_add(1, Ordering::Relaxed);
  198. let max_timeout = Duration::from_secs(30);
  199. tokio::task::spawn_blocking(move || {
  200. loop {
  201. if let Ok(mut queue) =
  202. futures::executor::block_on(timeout(max_timeout, receiver.lock()))
  203. {
  204. if let Ok(Some((_inputs, output))) =
  205. futures::executor::block_on(timeout(max_timeout, queue.recv()))
  206. {
  207. let exec_id: usize = exec_id.fetch_add(1, Ordering::Relaxed);
  208. // drop queue lock to release the mutex so any other
  209. // free VM can use it to listen for incoming messages
  210. drop(queue);
  211. let ret =
  212. Self::execute_program(storage.clone(), exec_id, &bytecode).unwrap();
  213. running.fetch_add(1, Ordering::Relaxed);
  214. let _ = output.send(ret).unwrap();
  215. continue;
  216. }
  217. }
  218. break;
  219. }
  220. println!("Lua listener is exiting");
  221. instances.fetch_sub(1, Ordering::Relaxed);
  222. });
  223. }
  224. pub async fn exec(&self, input: Vec<VarValue>) -> VarValue {
  225. let (return_notifier, return_listener) = oneshot::channel();
  226. Self::spawn(
  227. self.storage.clone(),
  228. self.opcodes.clone(),
  229. self.instances.clone(),
  230. self.execution_id.clone(),
  231. self.running.clone(),
  232. self.receiver.clone(),
  233. );
  234. self.sender
  235. .send((input, return_notifier))
  236. .await
  237. .expect("valid");
  238. return_listener.await.expect("valid")
  239. }
  240. }
  241. impl TryFrom<&Value<'_>> for VarValue {
  242. type Error = String;
  243. fn try_from(value: &Value<'_>) -> Result<Self, Self::Error> {
  244. match value {
  245. Value::Nil => Ok(VarValue::Nil),
  246. Value::Boolean(b) => Ok(VarValue::Boolean(*b)),
  247. Value::Integer(i) => Ok(VarValue::Integer((*i).into())),
  248. Value::Number(n) => Ok(VarValue::Number(*n)),
  249. Value::String(s) => Ok(VarValue::String(s.to_str().unwrap().to_owned())),
  250. Value::Table(t) => {
  251. let mut map = HashMap::new();
  252. let mut iter = t.clone().pairs::<String, Value>().enumerate();
  253. let mut is_vector = true;
  254. while let Some((id, Ok((k, v)))) = iter.next() {
  255. if Ok(id + 1) != k.parse() {
  256. is_vector = false;
  257. }
  258. map.insert(k, v.as_ref().try_into()?);
  259. }
  260. Ok(if is_vector {
  261. let mut values = map
  262. .into_iter()
  263. .map(|(k, v)| k.parse().map(|k| (k, v)))
  264. .collect::<Result<Vec<(usize, VarValue)>, _>>()
  265. .unwrap();
  266. values.sort_by(|(a, _), (b, _)| a.cmp(b));
  267. VarValue::Vector(values.into_iter().map(|(_, v)| v).collect())
  268. } else {
  269. VarValue::HashMap(map)
  270. })
  271. }
  272. x => Err(format!("Invalid type: {:?}", x)),
  273. }
  274. }
  275. }
  276. #[derive(Debug)]
  277. pub struct VarStorageMem {
  278. storage: RwLock<HashMap<String, VarValue>>,
  279. locks: RwLock<HashMap<String, usize>>,
  280. var_locked_by_instance: Mutex<HashMap<usize, Vec<String>>>,
  281. }
  282. impl Default for VarStorageMem {
  283. fn default() -> Self {
  284. Self {
  285. storage: RwLock::new(HashMap::new()),
  286. locks: RwLock::new(HashMap::new()),
  287. var_locked_by_instance: Mutex::new(HashMap::new()),
  288. }
  289. }
  290. }
  291. impl VarStorageMem {
  292. async fn lock(&self, instance: usize, var: &Variable) {
  293. let locks = self.locks.read().await;
  294. let name = var.name();
  295. if locks.get(name).map(|v| *v) == Some(instance) {
  296. // The variable is already locked by this instance
  297. return;
  298. }
  299. drop(locks);
  300. loop {
  301. // wait here while the locked is not null or it is locked by another
  302. // instance
  303. let locks = self.locks.read().await;
  304. let var_lock = locks.get(name).map(|v| *v);
  305. if var_lock.is_none() || var_lock == Some(instance) {
  306. break;
  307. }
  308. drop(locks);
  309. tokio::time::sleep(Duration::from_micros(10)).await;
  310. }
  311. loop {
  312. let mut locks = self.locks.write().await;
  313. let var_lock = locks.get(name).map(|v| *v);
  314. if !var_lock.is_none() {
  315. if var_lock == Some(instance) {
  316. break;
  317. }
  318. drop(locks);
  319. tokio::time::sleep(Duration::from_micros(10)).await;
  320. continue;
  321. }
  322. locks.insert(name.to_owned(), instance);
  323. let mut vars_by_instance = self.var_locked_by_instance.lock().await;
  324. if let Some(vars) = vars_by_instance.get_mut(&instance) {
  325. vars.push(name.to_owned());
  326. } else {
  327. vars_by_instance.insert(instance, vec![name.to_owned()]);
  328. }
  329. }
  330. }
  331. }
  332. #[async_trait::async_trait]
  333. impl VarStorage for VarStorageMem {
  334. async fn get(&self, instance: usize, var: Variable) -> VarValue {
  335. self.lock(instance, &var).await;
  336. self.storage
  337. .read()
  338. .await
  339. .get(var.name())
  340. .cloned()
  341. .unwrap_or(VarValue::Nil)
  342. }
  343. async fn set(&self, instance: usize, var: Variable, value: VarValue) {
  344. self.lock(instance, &var).await;
  345. self.storage
  346. .write()
  347. .await
  348. .insert(var.name().to_owned(), value);
  349. }
  350. async fn shutdown(&self, instance: usize) {
  351. let mut vars_by_instance = self.var_locked_by_instance.lock().await;
  352. let mut locks = self.locks.write().await;
  353. if let Some(vars) = vars_by_instance.remove(&instance) {
  354. for var in vars {
  355. if locks.get(&var).map(|v| *v) == Some(instance) {
  356. locks.remove(&var);
  357. }
  358. }
  359. }
  360. }
  361. }
  362. pub struct Runtime<K: Hash + Eq, X: VarStorage + 'static> {
  363. vms: RwLock<HashMap<K, Program<X>>>,
  364. }
  365. impl<K: Hash + Eq, X: VarStorage + 'static> Runtime<K, X> {
  366. pub fn new() -> Arc<Self> {
  367. Arc::new(Self {
  368. vms: RwLock::new(HashMap::new()),
  369. })
  370. }
  371. pub async fn register_program(&self, name: K, program: &str, storage: Arc<X>) -> bool {
  372. self.register_opcodes(name, Compiler::new().compile(program), storage)
  373. .await
  374. }
  375. pub async fn exec(&self, id: &K) -> Option<VarValue> {
  376. if let Some(vm) = self.vms.read().await.get(id) {
  377. Some(vm.exec(vec![VarValue::Integer(1)]).await)
  378. } else {
  379. None
  380. }
  381. }
  382. pub async fn register_opcodes(&self, name: K, opcodes: Vec<u8>, storage: Arc<X>) -> bool {
  383. let mut vms = self.vms.write().await;
  384. vms.insert(name, Program::new(opcodes, storage)).is_some()
  385. }
  386. pub async fn shutdown(&self) {
  387. let mut vms = self.vms.write().await;
  388. vms.clear();
  389. }
  390. }
  391. #[tokio::main]
  392. async fn main() {
  393. use std::{sync::Arc, time::Instant};
  394. let mem = Arc::new(VarStorageMem::default());
  395. async fn do_loop(vms: Arc<Runtime<String, VarStorageMem>>) {
  396. // Create N threads to execute the Lua code in parallel
  397. let num_threads = 400;
  398. let (tx, mut rx) = mpsc::channel(num_threads);
  399. for _ in 0..num_threads {
  400. let vm = vms.clone();
  401. let tx_clone = tx.clone();
  402. tokio::spawn(async move {
  403. let start_time = Instant::now();
  404. let result = vm.exec(&"foo".to_owned()).await;
  405. // Send the result back to the main thread
  406. let _ = tx_clone.send(result).await;
  407. let elapsed_time = Instant::now() - start_time;
  408. // Print the elapsed time in seconds and milliseconds
  409. println!(
  410. "Elapsed time: {} seconds {} milliseconds",
  411. elapsed_time.as_secs(),
  412. elapsed_time.as_millis(),
  413. );
  414. });
  415. }
  416. drop(tx);
  417. loop {
  418. let result = rx.recv().await;
  419. if result.is_none() {
  420. break;
  421. }
  422. println!("Result: {:?}", result.unwrap());
  423. }
  424. }
  425. let vms = Runtime::new();
  426. // Compile Lua code
  427. let _code = r#"
  428. function add(a, b)
  429. calls = calls + 1
  430. print("Call from old " .. pid .. " " .. calls)
  431. return a + b
  432. end
  433. print("hello world " .. pid)
  434. "#;
  435. let code = r#"
  436. if pid == nil then
  437. pid = 0
  438. end
  439. pid = pid + 1
  440. print("hello world " .. pid)
  441. return true
  442. "#;
  443. let _ = vms
  444. .register_program("foo".to_owned(), code, mem.clone())
  445. .await;
  446. do_loop(vms.clone()).await;
  447. let code = r#"
  448. if pid == nil then
  449. pid = 0
  450. end
  451. pid = pid + 1
  452. function add(a, b)
  453. foo = {1,"foo"}
  454. print("Call from new " .. pid .. " ")
  455. return a + b
  456. end
  457. print("hello world " .. pid .. " = " .. add(pid, pid))
  458. return false
  459. "#;
  460. let y = vms
  461. .register_program("foo".to_owned(), code, mem.clone())
  462. .await;
  463. tokio::time::sleep(Duration::from_secs(3)).await;
  464. do_loop(vms.clone()).await;
  465. vms.shutdown().await;
  466. tokio::time::sleep(Duration::from_secs(1)).await;
  467. println!("{} {:?}", "foo", mem);
  468. }