Przeglądaj źródła

Add proper types for Addr and Register

Cesar Rodas 10 miesięcy temu
rodzic
commit
6903f7e03d

+ 1 - 0
utxo/Cargo.toml

@@ -8,6 +8,7 @@ async-trait = "0.1.73"
 bech32 = "0.11.0"
 borsh = { version = "1.3.1", features = ["derive", "bytes", "de_strict_order"] }
 chrono = { version = "0.4.31", features = ["serde"] }
+derive_more = "0.99.17"
 futures = { version = "0.3.30", optional = true }
 hmac = "0.12.1"
 num = "0.4.3"

+ 43 - 3
utxo/src/filter_expr/mod.rs

@@ -9,8 +9,45 @@ mod program;
 mod runtime;
 mod value;
 
-type Register = usize;
-type Addr = usize;
+#[derive(
+    Debug,
+    Clone,
+    Copy,
+    PartialEq,
+    Eq,
+    Hash,
+    derive_more::Display,
+    derive_more::Deref,
+    derive_more::DerefMut,
+    derive_more::From,
+    derive_more::Into,
+)]
+pub struct Register(usize);
+
+#[derive(
+    Debug,
+    Clone,
+    Copy,
+    PartialEq,
+    Eq,
+    Hash,
+    derive_more::Display,
+    derive_more::Deref,
+    derive_more::DerefMut,
+    derive_more::From,
+    derive_more::Into,
+)]
+pub struct Addr(usize);
+
+impl Addr {
+    pub fn next(&mut self) {
+        self.0 += 1;
+    }
+
+    pub fn jump_to(&mut self, addr: &Addr) {
+        self.0 = addr.0;
+    }
+}
 
 #[derive(thiserror::Error, Debug)]
 pub enum Error {
@@ -51,7 +88,10 @@ pub enum Error {
     UnexpectedExprState,
 }
 
-use std::num::ParseIntError;
+use std::{
+    num::ParseIntError,
+    ops::{Deref, DerefMut},
+};
 
 use parser::Rule;
 

+ 30 - 30
utxo/src/filter_expr/opcode.rs

@@ -110,7 +110,7 @@ pub enum OpCode {
     GE(Register, Register, Register),
 
     // Branching Operations
-    /// JLABEL <address>
+    /// LABEL <address>
     /// Adds a label to the this part of the code
     LABEL(Addr),
 
@@ -144,35 +144,35 @@ pub enum OpCode {
 impl ToString for OpCode {
     fn to_string(&self) -> String {
         match self {
-            OpCode::LOAD(r, v) => format!("LOAD {} {:?}", r, v),
-            OpCode::LOAD_EXTERNAL(r, v) => format!("LOAD_EXTERNAL {} {:?}", r, v),
-            OpCode::CPY(r1, r2) => format!("CPY {} {}", r1, r2),
-            OpCode::MOV(r1, r2) => format!("MOV {} {}", r1, r2),
-            OpCode::ADD(r1, r2, r3) => format!("ADD {} {} {}", r1, r2, r3),
-            OpCode::SUB(r1, r2, r3) => format!("SUB {} {} {}", r1, r2, r3),
-            OpCode::MUL(r1, r2, r3) => format!("MUL {} {} {}", r1, r2, r3),
-            OpCode::DIV(r1, r2, r3) => format!("DIV {} {} {}", r1, r2, r3),
-            OpCode::MOD(r1, r2, r3) => format!("MOD {} {} {}", r1, r2, r3),
-            OpCode::NEG(r1, r2) => format!("NEG {} {}", r1, r2),
-            OpCode::AND(r1, r2, r3) => format!("AND {} {} {}", r1, r2, r3),
-            OpCode::OR(r1, r2, r3) => format!("OR {} {} {}", r1, r2, r3),
-            OpCode::XOR(r1, r2, r3) => format!("XOR {} {} {}", r1, r2, r3),
-            OpCode::NOT(r1, r2) => format!("NOT {} {}", r1, r2),
-            OpCode::SHL(r1, r2, v) => format!("SHL {} {} {:?}", r1, r2, v),
-            OpCode::SHR(r1, r2, v) => format!("SHR {} {} {:?}", r1, r2, v),
-            OpCode::EQ(r1, r2, r3) => format!("EQ {} {} {}", r1, r2, r3),
-            OpCode::NE(r1, r2, r3) => format!("NE {} {} {}", r1, r2, r3),
-            OpCode::LT(r1, r2, r3) => format!("LT {} {} {}", r1, r2, r3),
-            OpCode::LE(r1, r2, r3) => format!("LE {} {} {}", r1, r2, r3),
-            OpCode::GT(r1, r2, r3) => format!("GT {} {} {}", r1, r2, r3),
-            OpCode::GE(r1, r2, r3) => format!("GE {} {} {}", r1, r2, r3),
-            OpCode::LABEL(a) => format!("LABEL {}:", a),
-            OpCode::JMP(a) => format!("JMP {}", a),
-            OpCode::JEQ(r, a) => format!("JEQ {} {}", r, a),
-            OpCode::JNE(r, a) => format!("JNE {} {}", r, a),
-            OpCode::PUSH(r) => format!("PUSH {}", r),
-            OpCode::POP(r) => format!("POP {}", r),
-            OpCode::HLT(r) => format!("HLT {}", r),
+            OpCode::LOAD(r, v) => format!("LOAD {:?} {:?}", r, v),
+            OpCode::LOAD_EXTERNAL(r, v) => format!("LOAD_EXTERNAL {:?} {:?}", r, v),
+            OpCode::CPY(r1, r2) => format!("CPY {:?} {:?}", r1, r2),
+            OpCode::MOV(r1, r2) => format!("MOV {:?} {:?}", r1, r2),
+            OpCode::ADD(r1, r2, r3) => format!("ADD {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::SUB(r1, r2, r3) => format!("SUB {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::MUL(r1, r2, r3) => format!("MUL {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::DIV(r1, r2, r3) => format!("DIV {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::MOD(r1, r2, r3) => format!("MOD {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::NEG(r1, r2) => format!("NEG {:?} {:?}", r1, r2),
+            OpCode::AND(r1, r2, r3) => format!("AND {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::OR(r1, r2, r3) => format!("OR {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::XOR(r1, r2, r3) => format!("XOR {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::NOT(r1, r2) => format!("NOT {:?} {:?}", r1, r2),
+            OpCode::SHL(r1, r2, v) => format!("SHL {:?} {:?} {:?}", r1, r2, v),
+            OpCode::SHR(r1, r2, v) => format!("SHR {:?} {:?} {:?}", r1, r2, v),
+            OpCode::EQ(r1, r2, r3) => format!("EQ {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::NE(r1, r2, r3) => format!("NE {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::LT(r1, r2, r3) => format!("LT {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::LE(r1, r2, r3) => format!("LE {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::GT(r1, r2, r3) => format!("GT {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::GE(r1, r2, r3) => format!("GE {:?} {:?} {:?}", r1, r2, r3),
+            OpCode::LABEL(a) => format!("LABEL {:?}:", a),
+            OpCode::JMP(a) => format!("JMP {:?}", a),
+            OpCode::JEQ(r, a) => format!("JEQ {:?} {:?}", r, a),
+            OpCode::JNE(r, a) => format!("JNE {:?} {:?}", r, a),
+            OpCode::PUSH(r) => format!("PUSH {:?}", r),
+            OpCode::POP(r) => format!("POP {:?}", r),
+            OpCode::HLT(r) => format!("HLT {:?}", r),
         }
     }
 }

+ 27 - 34
utxo/src/filter_expr/program.rs

@@ -35,18 +35,18 @@ impl<'a> Compiler<'a> {
     pub fn new(expr: &'a Expr) -> Self {
         Self {
             expr,
-            current_register: 0,
-            labels: 0,
+            current_register: 0.into(),
+            labels: 0.into(),
         }
     }
 
     fn next_label(&mut self) -> Addr {
-        self.labels += 1;
+        *self.labels += 1;
         self.labels
     }
 
-    fn next_register(&mut self) -> usize {
-        self.current_register += 1;
+    fn next_register(&mut self) -> Register {
+        *self.current_register += 1;
         self.current_register
     }
 
@@ -54,7 +54,7 @@ impl<'a> Compiler<'a> {
         &mut self,
         expr: &'a Expr,
         exit_label: Addr,
-    ) -> Result<(Vec<OpCode>, usize), Error> {
+    ) -> Result<(Vec<OpCode>, Register), Error> {
         let mut return_value = self.next_register();
         Ok(match expr {
             Expr::Variable(name) => (
@@ -101,7 +101,7 @@ impl<'a> Compiler<'a> {
                         for (mut term_opcodes, term_return) in iter {
                             opcodes.append(&mut term_opcodes);
                             opcodes.push(OpCode::EQ(cmp, last_value, term_return));
-                            opcodes.push(OpCode::JNE(exit_label, cmp));
+                            opcodes.push(OpCode::JNE(cmp, exit_label));
                             opcodes.push(OpCode::MOV(last_value, term_return));
                         }
 
@@ -132,8 +132,8 @@ impl<'a> Compiler<'a> {
                             opcodes.append(&mut term_opcodes);
                             opcodes.push(OpCode::MOV(return_value, term_return_value));
                             match op {
-                                ExprOp::Or => opcodes.push(OpCode::JEQ(exit_label, return_value)),
-                                ExprOp::And => opcodes.push(OpCode::JNE(exit_label, return_value)),
+                                ExprOp::Or => opcodes.push(OpCode::JEQ(return_value, exit_label)),
+                                ExprOp::And => opcodes.push(OpCode::JNE(return_value, exit_label)),
                                 _ => unreachable!(),
                             };
                         }
@@ -154,12 +154,12 @@ impl<'a> Compiler<'a> {
             .map(|x| match x {
                 OpCode::LABEL(label) => (x, pos),
                 _ => {
-                    pos = pos + 1;
+                    pos = pos + 1usize;
                     (x, pos - 1)
                 }
             })
             .filter_map(|(opcode, pos)| match opcode {
-                OpCode::LABEL(id) => Some((*id, pos)),
+                OpCode::LABEL(id) => Some((*id, pos.into())),
                 _ => None,
             })
             .collect::<HashMap<_, _>>();
@@ -170,15 +170,15 @@ impl<'a> Compiler<'a> {
             .map(|opcode| {
                 Ok(match opcode {
                     OpCode::JMP(label) => {
-                        OpCode::JMP(*used_labels.get(&label).ok_or(Error::UnknownLabel(label))?)
+                        OpCode::JMP(*used_labels.get(&label).ok_or(Error::UnknownLabel(*label))?)
                     }
                     OpCode::JEQ(register, label) => OpCode::JEQ(
                         register,
-                        *used_labels.get(&label).ok_or(Error::UnknownLabel(label))?,
+                        *used_labels.get(&label).ok_or(Error::UnknownLabel(*label))?,
                     ),
                     OpCode::JNE(register, label) => OpCode::JNE(
                         register,
-                        *used_labels.get(&label).ok_or(Error::UnknownLabel(label))?,
+                        *used_labels.get(&label).ok_or(Error::UnknownLabel(*label))?,
                     ),
                     x => x,
                 })
@@ -200,26 +200,19 @@ impl Program {
         let ast = parse_query(code)?;
 
         let opcodes = ast.where_clause.map_or_else(
-            || Ok(vec![OpCode::LOAD(0, true.into()), OpCode::HLT(0)]),
+            || {
+                Ok(vec![
+                    OpCode::LOAD(0.into(), true.into()),
+                    OpCode::HLT(0.into()),
+                ])
+            },
             |expr| Compiler::new(&expr).compile(),
         )?;
 
-        println!(
-            "{}",
-            opcodes
-                .iter()
-                .map(|x| match x {
-                    OpCode::HLT(_) | OpCode::LABEL(_) => x.to_string(),
-                    x => format!("\t{}", x.to_string()),
-                })
-                .collect::<Vec<_>>()
-                .join("\n")
-        );
-
         Ok(Self {
             dbg_opcodes: opcodes.clone(),
             opcodes: Compiler::resolve_label_to_addr(opcodes)?,
-            start_at: 0,
+            start_at: 0.into(),
             initial_register: vec![],
         })
     }
@@ -283,13 +276,13 @@ mod test {
         let program = Program {
             dbg_opcodes: vec![],
             opcodes: vec![
-                OpCode::LOAD(1, 12.into()),
-                OpCode::LOAD(2, 13.into()),
-                OpCode::ADD(3, 1, 2),
-                OpCode::ADD(4, 0, 3),
-                OpCode::HLT(4),
+                OpCode::LOAD(1.into(), 12.into()),
+                OpCode::LOAD(2.into(), 13.into()),
+                OpCode::ADD(3.into(), 1.into(), 2.into()),
+                OpCode::ADD(4.into(), 0.into(), 3.into()),
+                OpCode::HLT(4.into()),
             ],
-            start_at: 0,
+            start_at: 0.into(),
             initial_register: vec![15.into()],
         };
         let x = program.execute(None).expect("valid execution");

+ 38 - 31
utxo/src/filter_expr/runtime.rs

@@ -1,7 +1,10 @@
-use super::{opcode::OpCode, value::Value, Addr, Error};
+use super::{opcode::OpCode, value::Value, Addr, Error, Register};
 use crate::Transaction;
 use num::CheckedAdd;
-use std::{collections::VecDeque, ops::Deref};
+use std::{
+    collections::{HashMap, VecDeque},
+    ops::Deref,
+};
 
 #[derive(Debug, PartialEq, PartialOrd)]
 /// Value or reference to a value.
@@ -57,7 +60,14 @@ impl<'a> From<Value> for ValueOrRef<'a> {
 macro_rules! get {
     ($r:expr,$pos:expr) => {
         ($r.get($pos)
-            .ok_or_else(|| Error::RegisterOutOfBoundaries($pos))?)
+            .ok_or_else(|| Error::RegisterOutOfBoundaries($pos.to_owned()))?)
+    };
+}
+
+macro_rules! remove {
+    ($r:expr,$pos:expr) => {
+        ($r.remove($pos)
+            .ok_or_else(|| Error::RegisterOutOfBoundaries($pos.to_owned()))?)
     };
 }
 
@@ -66,10 +76,7 @@ macro_rules! set {
         (if let Some(previous_value) = $r.get_mut($pos) {
             *previous_value = $new_value;
         } else {
-            $r.push_back($new_value);
-            if $r.len() != $pos + 1 {
-                return Err(Error::RegisterOutOfBoundaries($pos));
-            }
+            $r.insert(($pos).to_owned(), $new_value);
         })
     };
 }
@@ -82,36 +89,36 @@ pub fn execute(
     start_at: Addr,
 ) -> Result<Value, Error> {
     let mut execution = start_at;
-    let mut registers: VecDeque<ValueOrRef> = initial_registers
+    let mut registers = initial_registers
         .iter()
-        .map(|a| a.into())
-        .collect::<VecDeque<_>>();
+        .enumerate()
+        .map(|(pos, a)| (pos.into(), a.into()))
+        .collect::<HashMap<Register, ValueOrRef>>();
 
     loop {
-        match code.get(execution).ok_or(Error::OutOfBoundaries)? {
-            OpCode::LOAD(dst, ref val) => set!(registers, *dst, val.into()),
+        match code.get(*execution).ok_or(Error::OutOfBoundaries)? {
+            OpCode::LOAD(dst, ref val) => set!(registers, dst, val.into()),
             OpCode::CPY(dst, reg2) => {
-                let value = get!(registers, *reg2).clone();
-                set!(registers, *dst, value);
+                let value = get!(registers, reg2).clone();
+                set!(registers, dst, value);
             }
             OpCode::MOV(dst, reg2) => {
-                let _ = get!(registers, *reg2);
-                let previous_value = std::mem::replace(&mut registers[*reg2], Value::Nil.into());
-                set!(registers, *dst, previous_value);
+                let previous_value = remove!(registers, reg2);
+                set!(registers, dst, previous_value);
             }
             OpCode::ADD(dst, a, b) => {
-                let new_value = get!(registers, *a)
-                    .checked_add(get!(registers, *b))
+                let new_value = get!(registers, a)
+                    .checked_add(get!(registers, b))
                     .ok_or(Error::Overflow)?
                     .into();
-                set!(registers, *dst, new_value);
+                set!(registers, dst, new_value);
             }
             OpCode::AND(dst, reg2, reg3) => {
                 let new_value = Value::Bool(
-                    get!(registers, *reg2).as_boolean()? && get!(registers, *reg3).as_boolean()?,
+                    get!(registers, reg2).as_boolean()? && get!(registers, reg3).as_boolean()?,
                 )
                 .into();
-                set!(registers, *dst, new_value);
+                set!(registers, dst, new_value);
             }
             OpCode::OR(dst1, reg2, reg3) => {
                 todo!()
@@ -120,34 +127,34 @@ pub fn execute(
                 todo!()
             }
             OpCode::NOT(dst, reg2) => {
-                let new_value = Value::Bool(!get!(registers, *reg2).as_boolean()?).into();
-                set!(registers, *dst, new_value);
+                let new_value = Value::Bool(!get!(registers, reg2).as_boolean()?).into();
+                set!(registers, dst, new_value);
             }
             OpCode::JMP(addr) => {
-                execution = *addr;
+                execution.jump_to(addr);
                 continue;
             }
             OpCode::JEQ(reg, addr) => {
-                if get!(registers, *reg).as_boolean()? {
-                    execution = *addr;
+                if get!(registers, reg).as_boolean()? {
+                    execution.jump_to(addr);
                     continue;
                 }
             }
             OpCode::JNE(reg, addr) => {
-                if !get!(registers, *reg).as_boolean()? {
-                    execution = *addr;
+                if !get!(registers, reg).as_boolean()? {
+                    execution.jump_to(addr);
                     continue;
                 }
             }
             OpCode::HLT(return_register) => {
                 return registers
-                    .remove(*return_register)
+                    .remove(return_register)
                     .map(|x| x.into())
                     .ok_or(Error::EmptyRegisters)
             }
             _ => todo!(),
         }
 
-        execution += 1;
+        execution.next();
     }
 }