A Virtual Machine

Our virtual machine (VM) simulates a small computer with three parts:

  1. An instruction pointer (IP) that holds the address of the next instruction to execute, initialized to 0 at startup.

  2. Four general-purpose registers R0–R3 that instructions read and write directly. All arithmetic goes through registers; there are no memory-to-memory operations.

  3. 256 words of RAM that hold both the program and its data. Addresses are one byte wide, so they range from 0 to 255, which is why 256 words is a natural size.

Architecture Constants

/-- Number of general-purpose registers. -/
def NUM_REG : Nat := 4

/-- Number of addressable words in RAM. -/
def RAM_LEN : Nat := 256

/-- Mask to select one byte from a packed instruction word. -/
def OP_MASK : UInt32 := 0xFF

/-- Number of bits to shift when packing or unpacking instruction bytes. -/
def OP_SHIFT : UInt32 := 8
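To make the roles of OP_MASK and OP_SHIFT concrete, here is a small self-contained sketch. The helper byteAt is hypothetical and not part of the VM; it shifts a word right by a multiple of 8 bits and masks off the low byte, which is the same pattern fetch uses to decode an instruction.

```lean
namespace ByteSketch

/-- Hypothetical helper: return byte `k` of `word`, where byte 0 is the lowest.
    Uses the literal values 8 and 0xFF, which match OP_SHIFT and OP_MASK. -/
def byteAt (word : UInt32) (k : UInt32) : UInt32 :=
  (word >>> (8 * k)) &&& 0xFF

-- 0x2A0002 holds the bytes 0x2A, 0x00, 0x02 from highest to lowest.
#guard byteAt 0x2A0002 0 == 0x02
#guard byteAt 0x2A0002 1 == 0x00
#guard byteAt 0x2A0002 2 == 0x2A

end ByteSketch
```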

Each op code fits in one byte, the lowest byte of an instruction word:

def OP_HLT : UInt32 := 0x1  -- halt the program
def OP_LDC : UInt32 := 0x2  -- load a constant value into a register
def OP_LDR : UInt32 := 0x3  -- load from the RAM address held in a register
def OP_CPY : UInt32 := 0x4  -- copy one register's value into another
def OP_STR : UInt32 := 0x5  -- store a register value into RAM at address in register
def OP_ADD : UInt32 := 0x6  -- add two registers, writing result to first
def OP_SUB : UInt32 := 0x7  -- subtract second register from first
def OP_BEQ : UInt32 := 0x8  -- branch to address if register equals zero
def OP_BNE : UInt32 := 0x9  -- branch to address if register is not zero
def OP_PRR : UInt32 := 0xA  -- append register value to output
def OP_PRM : UInt32 := 0xB  -- append RAM value at address in register to output

VM State

/-- The complete state of the virtual machine at one moment in time. -/
structure VMState where
  ip     : Nat           -- address of the next instruction to fetch
  reg    : Array UInt32  -- NUM_REG general-purpose registers
  ram    : Array UInt32  -- RAM_LEN words of addressable memory
  output : List String   -- lines produced by prr and prm instructions
  deriving Repr
/-- Create a fresh VM state with the given program loaded starting at address 0.
    Returns none if the program is longer than RAM_LEN words. -/
def VMState.init (program : Array UInt32) : Option VMState :=
  if program.size > RAM_LEN then none
  else
    let ram := (Array.range RAM_LEN).map fun i => program.getD i 0
    some { ip := 0, reg := Array.mkArray NUM_REG 0, ram, output := [] }
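A couple of illustrative checks on init, assuming VMState.init and RAM_LEN above are in scope: an oversized program is rejected, and a short one is padded with zeros to fill RAM.

```lean
-- 257 words cannot fit in 256 words of RAM, so init returns none.
#guard (VMState.init (List.replicate 257 (0 : UInt32)).toArray).isNone

-- A one-word program lands at address 0; the rest of RAM is zero.
#guard
  match VMState.init #[7] with
  | none   => false
  | some s => s.ram.size == 256 && s.ram[0]! == 7 && s.ram[255]! == 0
```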

Fetching Instructions

/-- Read the instruction at the current IP, advance IP by one, and return
    the decoded (op, arg0, arg1) triple together with the updated state.
    Instructions are packed as three bytes: op in bits 7-0, arg0 in 15-8,
    arg1 in 23-16. -/
def VMState.fetch (s : VMState) : (UInt32 × UInt32 × UInt32) × VMState :=
  let instr := s.ram[s.ip]!
  let op    := instr &&& OP_MASK
  let arg0  := (instr >>> OP_SHIFT) &&& OP_MASK
  let arg1  := (instr >>> (OP_SHIFT * 2)) &&& OP_MASK
  ((op, arg0, arg1), { s with ip := s.ip + 1 })
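As an illustrative check (assuming the definitions above), fetching the word for ldc R0 42 splits it into op 2, arg0 0 and arg1 42, and advances the IP to 1:

```lean
#guard
  match VMState.init #[0x2A0002] with
  | none   => false
  | some s =>
    let ((op, arg0, arg1), s') := s.fetch
    op == OP_LDC && arg0 == 0 && arg1 == 42 && s'.ip == 1
```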

Executing Instructions

/-- Execute one instruction and return the updated state,
    or none if the instruction was hlt. -/
def VMState.step (s : VMState) : Option VMState :=
  let ((op, arg0, arg1), s') := s.fetch
  let i0 := arg0.toNat
  let i1 := arg1.toNat
  -- hlt: stop the machine
  if op == OP_HLT then none
  -- ldc: R[arg0] := arg1 (the constant itself)
  else if op == OP_LDC then
    some { s' with reg := s'.reg.set! i0 arg1 }
  -- ldr: R[arg0] := RAM[R[arg1]]
  else if op == OP_LDR then
    some { s' with reg := s'.reg.set! i0 (s'.ram[s'.reg[i1]!.toNat]!) }
  -- cpy: R[arg0] := R[arg1]
  else if op == OP_CPY then
    some { s' with reg := s'.reg.set! i0 s'.reg[i1]! }
  -- str: RAM[R[arg1]] := R[arg0]
  else if op == OP_STR then
    some { s' with ram := s'.ram.set! (s'.reg[i1]!.toNat) s'.reg[i0]! }
  -- add: R[arg0] := R[arg0] + R[arg1], wrapping modulo 2^32
  else if op == OP_ADD then
    some { s' with reg := s'.reg.set! i0 (s'.reg[i0]! + s'.reg[i1]!) }
  -- sub: R[arg0] := R[arg0] - R[arg1], wrapping modulo 2^32
  else if op == OP_SUB then
    some { s' with reg := s'.reg.set! i0 (s'.reg[i0]! - s'.reg[i1]!) }
  -- beq: jump to address arg1 if R[arg0] is zero
  else if op == OP_BEQ then
    if s'.reg[i0]! == 0 then some { s' with ip := i1 }
    else some s'
  -- bne: jump to address arg1 if R[arg0] is not zero
  else if op == OP_BNE then
    if s'.reg[i0]! != 0 then some { s' with ip := i1 }
    else some s'
  -- prr: append R[arg0] to the output
  else if op == OP_PRR then
    some { s' with output := s'.output ++ [s!"{s'.reg[i0]!}"] }
  -- prm: append RAM[R[arg0]] to the output
  else if op == OP_PRM then
    some { s' with output := s'.output ++ [s!"{s'.ram[s'.reg[i0]!.toNat]!}"] }
  -- any other op code is a programming error
  else
    panic! s!"Unknown op {op}"
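A minimal sketch of step's return convention, assuming the definitions above: hlt produces none, and any other instruction produces the next state.

```lean
-- hlt: step signals termination with none.
#guard
  match VMState.init #[0x000001] with
  | none   => false
  | some s => s.step.isNone

-- ldc R1 4 (0x040102): step returns a state with R1 = 4 and IP = 1.
#guard
  match VMState.init #[0x040102] with
  | none   => false
  | some s =>
    match s.step with
    | none    => false
    | some s' => s'.reg[1]! == 4 && s'.ip == 1
```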

Example: Store Register Value in RAM

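The str instruction writes the value of its first register to the RAM address held in its second register:
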
  else if op == OP_STR then
    some { s' with ram := s'.ram.set! (s'.reg[i1]!.toNat) s'.reg[i0]! }
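As an illustration, assuming VMState.init and VMState.step above, this program loads 99 into R0 and the address 20 into R1, then stores R0 at that address. It calls step three times rather than using run, which is defined later:

```lean
-- ldc R0 99 ; ldc R1 20 ; str R0 R1  →  RAM[20] == 99
#guard
  match VMState.init #[0x630002, 0x140102, 0x010005] with
  | none   => false
  | some s =>
    -- Three steps execute the three instructions (no hlt needed).
    match s.step >>= VMState.step >>= VMState.step with
    | none    => false
    | some s' => s'.ram[20]! == 99
```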

Example: Add Two Registers

The add instruction overwrites its first register with the sum of both registers:

  else if op == OP_ADD then
    some { s' with reg := s'.reg.set! i0 (s'.reg[i0]! + s'.reg[i1]!) }
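Registers hold UInt32 values, so add and sub wrap around modulo 2^32 instead of overflowing. This self-contained check shows the underlying Lean arithmetic:

```lean
-- The largest UInt32 plus one wraps to zero...
#guard (4294967295 : UInt32) + 1 == 0
-- ...and zero minus one wraps to the largest UInt32.
#guard (0 : UInt32) - 1 == 4294967295
```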

Example: Jump If

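The beq instruction replaces the IP with its second operand when its register is zero; otherwise execution continues with the next instruction:
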
  else if op == OP_BEQ then
    if s'.reg[i0]! == 0 then some { s' with ip := i1 }
    else some s'
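Here is a taken branch, again assuming VMState.init and VMState.step above. R0 starts at 0, so beq jumps over the second ldc:

```lean
-- ldc R1 5 ; beq R0 3 ; ldc R1 7 ; hlt
#guard
  match VMState.init #[0x050102, 0x030008, 0x070102, 0x000001] with
  | none   => false
  | some s =>
    -- After two steps the IP is 3 (the hlt) and R1 still holds 5.
    match s.step >>= VMState.step with
    | none    => false
    | some s' => s'.ip == 3 && s'.reg[1]! == 5
```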

Running a Program

/-- Run the VM for up to `fuel` steps, returning the final state.
    Halts early when a hlt instruction is encountered.
    The fuel parameter prevents non-termination in tests and proofs. -/
def run (fuel : Nat) (s : VMState) : VMState :=
  match fuel with
  | 0     => s
  | n + 1 =>
    match s.step with
    | none    => s
    | some s' => run n s'

Three small hand-encoded programs check that the machine behaves as expected:

-- A hlt-only program leaves all registers at zero.
#guard
  match VMState.init #[0x000001] with
  | none   => false
  | some s => (run 10 s).reg[0]! == 0

-- ldc R0 42 ; hlt  →  R0 == 42
#guard
  match VMState.init #[0x2A0002, 0x000001] with
  | none   => false
  | some s => (run 10 s).reg[0]! == 42

-- ldc R0 3 ; ldc R1 4 ; add R0 R1 ; hlt  →  R0 == 7
#guard
  match VMState.init #[0x030002, 0x040102, 0x010006, 0x000001] with
  | none   => false
  | some s => (run 10 s).reg[0]! == 7
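The fuel argument is what guarantees that run terminates even when the program does not. As a sketch, the one-word program beq R0 0 (encoded as 0x000008) jumps to itself forever because R0 stays 0, so run stops only when the fuel is used up:

```lean
#guard
  match VMState.init #[0x000008] with
  | none   => false
  | some s =>
    let s' := run 50 s
    -- After 50 steps the machine is still looping at address 0.
    s'.ip == 0 && s'.output.isEmpty
```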

Assembly Code

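Writing packed words by hand is tedious and error-prone, so the next step is an assembler that translates mnemonics such as ldc R0 42 into instruction words. It starts from a table of op codes and operand formats:
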
/-- Op code and operand format for one instruction.
    Format codes: "--" no operands, "r-" one register,
    "rr" two registers, "rv" register and value. -/
structure OpInfo where
  code : UInt32
  fmt  : String

/-- The complete instruction set, indexed by mnemonic. -/
def OPS : List (String × OpInfo) := [
  ("hlt", { code := OP_HLT, fmt := "--" }),
  ("ldc", { code := OP_LDC, fmt := "rv" }),
  ("ldr", { code := OP_LDR, fmt := "rr" }),
  ("cpy", { code := OP_CPY, fmt := "rr" }),
  ("str", { code := OP_STR, fmt := "rr" }),
  ("add", { code := OP_ADD, fmt := "rr" }),
  ("sub", { code := OP_SUB, fmt := "rr" }),
  ("beq", { code := OP_BEQ, fmt := "rv" }),
  ("bne", { code := OP_BNE, fmt := "rv" }),
  ("prr", { code := OP_PRR, fmt := "r-" }),
  ("prm", { code := OP_PRM, fmt := "r-" }),
]
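Two illustrative sanity checks on the table, assuming OPS and the OP_ constants above are in scope: no two mnemonics share an op code, and looking up a name returns the expected constant.

```lean
-- Every mnemonic maps to a distinct op code...
#guard (OPS.map (·.2.code)).eraseDups.length == OPS.length
-- ...and a lookup by name returns the matching constant.
#guard (OPS.find? (·.1 == "add")).map (·.2.code) == some OP_ADD
```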

Parsing Helpers

/-- Parse "R0".."R3" into a register index, or return none. -/
def parseReg (token : String) : Option UInt32 :=
  if token.startsWith "R" then
    match (token.drop 1).toNat? with
    | some n => if n < NUM_REG then some n.toUInt32 else none
    | none   => none
  else none

/-- Parse a decimal literal or @label reference into a value, or return none. -/
def parseVal (token : String) (labels : List (String × Nat)) : Option UInt32 :=
  if token.startsWith "@" then
    let lbl := token.drop 1
    (labels.find? fun (k, _) => k == lbl).map fun (_, v) => v.toUInt32
  else
    token.toNat?.map (·.toUInt32)

/-- Pack a list of byte-sized values into one word. Each step of the fold shifts
    the accumulator left by OP_SHIFT bits and ORs in the next value, so the first
    value ends up in the highest byte and the last in the lowest.
    Mirrors Python's _combine. -/
def combine (args : List UInt32) : UInt32 :=
  args.foldl (fun acc a => (acc <<< OP_SHIFT) ||| a) 0
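A few illustrative checks on the parsing helpers, assuming parseReg, parseVal and combine above. The label table passed to parseVal is a made-up example:

```lean
-- Registers: only "R0".."R3" are accepted.
#guard parseReg "R2" == some 2
#guard parseReg "R4" == none
#guard parseReg "X1" == none

-- Values: decimal literals, or @label looked up in the label table.
#guard parseVal "42" [] == some 42
#guard parseVal "@loop" [("loop", 2)] == some 2
#guard parseVal "@missing" [] == none

-- combine puts its first argument in the highest byte.
#guard combine [0x2A, 0x00, OP_LDC] == 0x2A0002
```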

Finding Labels

/-- Scan cleaned source lines and record the address of each label.
    Labels end with ":" and do not occupy any instruction slot. -/
def findLabels (lines : List String) : List (String × Nat) :=
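  -- acc holds the labels found so far and the address of the next instruction.
  let go (acc : List (String × Nat) × Nat) (line : String) : List (String × Nat) × Nat :=
    let (labels, loc) := acc
    if line.endsWith ":" then
      -- A label names the address of the next instruction but takes no slot.
      (((line.dropRight 1).trim, loc) :: labels, loc)
    else
      (labels, loc + 1)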
  (lines.foldl go ([], 0)).1
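For example, assuming findLabels above, labels do not advance the address counter, and the result lists the most recently seen label first:

```lean
-- "loop" names address 1 and "done" names address 2.
#guard
  findLabels ["ldc R0 0", "loop:", "prr R0", "done:", "hlt"] == [("done", 2), ("loop", 1)]
```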

Compiling One Line

/-- Compile one instruction line into a packed word.
    Returns none if the mnemonic is unknown or operands are malformed. -/
def compileLine (line : String) (labels : List (String × Nat)) : Option UInt32 :=
  let tokens := (line.splitOn " ").filter (· != "")
  match tokens with
  | [] => none
  | mnem :: args =>
    match OPS.find? fun (name, _) => name == mnem with
    | none           => none
    | some (_, info) =>
      match info.fmt with
      | "--" => if args.isEmpty then some (combine [info.code]) else none
      | "r-" =>
        match args with
        | [r] => (parseReg r).map fun r0 => combine [r0, info.code]
        | _   => none
      | "rr" =>
        match args with
        | [r0s, r1s] =>
          match parseReg r0s, parseReg r1s with
          | some r0, some r1 => some (combine [r1, r0, info.code])
          | _,       _       => none
        | _ => none
      | "rv" =>
        match args with
        | [r0s, vs] =>
          match parseReg r0s, parseVal vs labels with
          | some r0, some v => some (combine [v, r0, info.code])
          | _,       _      => none
        | _ => none
      | _ => none
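Some illustrative checks on compileLine, assuming the definitions above; the label table is again a made-up example:

```lean
#guard compileLine "add R0 R1" [] == some 0x010006
#guard compileLine "bne R2 @loop" [("loop", 2)] == some 0x020209
-- A wrong operand count or an unknown mnemonic is rejected.
#guard compileLine "add R0" [] == none
#guard compileLine "jmp 3" [] == none
```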

Assembling a Program

/-- Assemble a list of source lines into an array of instruction words.
    Strips blank lines and comments (lines beginning with #), resolves
    labels, and compiles each remaining line. Returns none if any line
    fails to compile. -/
def assemble (lines : List String) : Option (Array UInt32) := do
  let lines :=
    lines
    |>.map String.trim
    |>.filter (fun l => l != "" && !l.startsWith "#")
  let labels := findLabels lines
  let instrs := lines.filter (fun l => !l.endsWith ":")
  let words ← instrs.mapM (fun l => compileLine l labels)
  return words.toArray
-- Single hlt instruction encodes to op code 1 in the low byte.
#guard assemble ["hlt"] == some #[0x000001]

-- ldc R0 42: code=2, reg=0, val=0x2A → (0x2A << 16) | (0 << 8) | 2
#guard assemble ["ldc R0 42", "hlt"] == some #[0x2A0002, 0x000001]

-- Comments and blank lines are stripped before compilation.
#guard assemble ["# load then stop", "", "ldc R0 1", "hlt"] == some #[0x010002, 0x000001]
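Two further checks, assuming assemble above, exercise label resolution: a label on the first line resolves to address 0, and an undefined label makes the whole assembly fail.

```lean
-- bne R0 @loop with loop = 0 encodes as op 9, arg0 0, arg1 0.
#guard assemble ["loop:", "bne R0 @loop", "hlt"] == some #[0x000009, 0x000001]
-- parseVal finds no "nowhere" label, so compileLine and assemble return none.
#guard assemble ["bne R0 @nowhere", "hlt"] == none
```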

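End-to-End Example

The program below prints 0, 1 and 2 and then halts. The -- annotations explain each line for the reader; they are not assembler syntax, since the assembler only strips whole lines that start with #.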

ldc R0 0      -- R0 = loop index
ldc R1 3      -- R1 = loop limit
loop:
prr R0        -- print index
ldc R2 1      -- R2 = 1 (needed because add is register-to-register)
add R0 R2     -- R0 += 1
cpy R2 R1     -- R2 = limit (sub overwrites its first operand)
sub R2 R0     -- R2 = limit - index
bne R2 @loop  -- repeat while R2 != 0
hlt

The same program, written as a list of strings, can be assembled and run as a test:

-- Assemble and run the count-up program: print 0, 1, 2 then halt.
#guard
  let src := [
    "ldc R0 0",     -- R0 = loop index, starting at 0
    "ldc R1 3",     -- R1 = loop limit
    "loop:",
    "prr R0",       -- print current index
    "ldc R2 1",     -- R2 = 1  (register-to-register add requires a register)
    "add R0 R2",    -- R0 += 1
    "cpy R2 R1",    -- R2 = R1 (copy limit; sub destroys its first operand)
    "sub R2 R0",    -- R2 = limit - index
    "bne R2 @loop", -- loop while R2 != 0
    "hlt"
  ]
  match assemble src with
  | none      => false
  | some prog =>
    match VMState.init prog with
    | none   => false
    | some s => (run 100 s).output == ["0", "1", "2"]

Exercises

Trace by Hand (10 min)

Write out the values of R0, R1, and R2 after each instruction in the first two iterations of the count-up loop above. At which step does bne decide not to jump?

Swap Two Registers (20 min)

Write an assembly program that swaps the values of R1 and R2 without changing R0 or R3. Because no other register is free to use as scratch space, you will need to hold one value temporarily in RAM. Encode the program as a #guard that checks the final register state.

Increment and Decrement (20 min)

Add OP_INC and OP_DEC constants to architecture.lean with codes 0xC and 0xD. Add matching branches to VMState.step in machine.lean, and add entries with format "r-" (one register) to OPS in assembler.lean. Rewrite the count-up example using inc instead of ldc R2 1 followed by add R0 R2.

Disassembler (30 min)

Write a function disassemble (prog : Array UInt32) : List String that turns each packed word back into a mnemonic string such as "ldc R0 42". Since labels are not stored in the machine code, generate synthetic names L000, L001, … for any address that appears as the target of a beq or bne instruction.

Store and Load an Array (30 min)

Write an assembly program (as a Lean string list) that stores the values 0–3 into consecutive memory locations starting at address 20, then reads them back and prints each one. Verify with a #guard that output == ["0", "1", "2", "3"].

Call and Return (45 min)

Add a stack pointer register SP initialized to address 255. Add psh (push: write reg[arg0] to ram[SP] then decrement SP) and pop (increment SP then read ram[SP] into reg[arg0]). Using these instructions, write a subroutine that doubles the value in R0, called twice with different inputs, and verify with #guard.