Open Source Adventures: Episode 04: Automated Type Conversion for Crystal Z3

Our Crystal Z3 is progressing nicely, and in this episode we want to reach the point where we can simply write solver.assert 2 + Z3::IntSort[:x] == 4.

It's also time to split the code into multiple files. This is how I usually like to work: throw everything into one file at first, and only move things out as they stabilize a bit.

Code structure

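I'll need to turn this into a proper Crystal "shard" at some point, but for now let's use this structure:

  z3.cr
  z3/libz3.cr
  z3/api.cr
  z3/solver.cr
  z3/model.cr
  z3/int_sort.cr
  z3/bool_sort.cr
  z3/int_expr.cr
  z3/bool_expr.cr
  z3/core_ext.cr
  send_more_money.cr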

Since so much code moved around, I'll repeat all of it here. In future episodes I'll probably only show what actually changed.

z3/libz3.cr

I added a few missing functions (for - and /):

module Z3
  @[Link("z3")]
  lib LibZ3
    type Ast = Void*
    type Config = Void*
    type Context = Void*
    type Model = Void*
    type Solver = Void*
    type Sort = Void*
    type Symbol = Void*

    enum LBool : Int32
      False = -1
      Undefined = 0
      True = 1
    end

    # Just list the ones we need; there are about 700 API calls in total
    fun mk_add = Z3_mk_add(ctx : Context, count : UInt32, asts : Ast*) : Ast
    fun mk_bool_sort = Z3_mk_bool_sort(ctx : Context) : Sort
    fun mk_config = Z3_mk_config() : Config
    fun mk_const = Z3_mk_const(ctx : Context, name : Symbol, sort : Sort) : Ast
    fun mk_context = Z3_mk_context(cfg : Config) : Context
    fun mk_distinct = Z3_mk_distinct(ctx : Context, count : UInt32, asts : Ast*) : Ast
    fun mk_div = Z3_mk_div(ctx : Context, a : Ast, b : Ast) : Ast
    fun mk_eq = Z3_mk_eq(ctx : Context, a : Ast, b : Ast) : Ast
    fun mk_ge = Z3_mk_ge(ctx : Context, a : Ast, b : Ast) : Ast
    fun mk_gt = Z3_mk_gt(ctx : Context, a : Ast, b : Ast) : Ast
    fun mk_int_sort = Z3_mk_int_sort(ctx : Context) : Sort
    fun mk_le = Z3_mk_le(ctx : Context, a : Ast, b : Ast) : Ast
    fun mk_lt = Z3_mk_lt(ctx : Context, a : Ast, b : Ast) : Ast
    fun mk_mul = Z3_mk_mul(ctx : Context, count : UInt32, asts : Ast*) : Ast
    fun mk_numeral = Z3_mk_numeral(ctx : Context, s : UInt8*, sort : Sort) : Ast
    fun mk_solver = Z3_mk_solver(ctx : Context) : Solver
    fun mk_string_symbol = Z3_mk_string_symbol(ctx : Context, s : UInt8*) : Symbol
    fun mk_sub = Z3_mk_sub(ctx : Context, count : UInt32, asts : Ast*) : Ast
    fun model_to_string = Z3_model_to_string(ctx : Context, model : Model) : UInt8*
    fun solver_assert = Z3_solver_assert(ctx : Context, solver : Solver, ast : Ast) : Void
    fun solver_check = Z3_solver_check(ctx : Context, solver : Solver) : LBool
    fun solver_get_model = Z3_solver_get_model(ctx : Context, solver : Solver) : Model
  end
end
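The fun crystal_name = C_symbol(...) form deserves a closer look: the name on the left is only what Crystal code calls, while the name on the right is the C symbol that actually gets linked. Nothing checks that the two agree, so a copy-paste slip such as aliasing mk_sub to Z3_mk_add would quietly turn subtraction into addition. Here's a minimal sketch of the same aliasing against libc's abs, so it runs without Z3; DemoLibC and c_abs are made-up names for illustration.

```crystal
# Hypothetical demo binding. The identifier left of "=" is the Crystal-side
# name; the identifier right of "=" is the C symbol the linker looks up.
# libc is linked into every Crystal program, so no @[Link] is needed here.
lib DemoLibC
  fun c_abs = abs(value : Int32) : Int32
end

puts DemoLibC.c_abs(-5) # => 5
```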

z3/api.cr

I'm not really sure about this naming. This file just gained the missing functions for - and /.

This file is getting big, and at some point I'd like to reduce all that copypasta code, but we have more important things to deal with first.

module Z3
  module API
    extend self

    Context = LibZ3.mk_context(LibZ3.mk_config)

    def mk_solver
      LibZ3.mk_solver(Context)
    end

    def mk_numeral(num, sort)
      LibZ3.mk_numeral(Context, num.to_s, sort)
    end

    def mk_const(name, sort)
      name_sym = LibZ3.mk_string_symbol(Context, name)
      LibZ3.mk_const(Context, name_sym, sort)
    end

    def mk_eq(a, b)
      LibZ3.mk_eq(Context, a, b)
    end

    def mk_ge(a, b)
      LibZ3.mk_ge(Context, a, b)
    end

    def mk_gt(a, b)
      LibZ3.mk_gt(Context, a, b)
    end

    def mk_le(a, b)
      LibZ3.mk_le(Context, a, b)
    end

    def mk_lt(a, b)
      LibZ3.mk_lt(Context, a, b)
    end

    def mk_div(a, b)
      LibZ3.mk_div(Context, a, b)
    end

    def mk_add(asts)
      LibZ3.mk_add(Context, asts.size, asts)
    end

    def mk_mul(asts)
      LibZ3.mk_mul(Context, asts.size, asts)
    end

    def mk_sub(asts)
      LibZ3.mk_sub(Context, asts.size, asts)
    end

    def mk_distinct(asts)
      LibZ3.mk_distinct(Context, asts.size, asts)
    end

    def solver_assert(solver, ast)
      LibZ3.solver_assert(Context, solver, ast)
    end

    def solver_check(solver)
      LibZ3.solver_check(Context, solver)
    end

    def solver_get_model(solver)
      LibZ3.solver_get_model(Context, solver)
    end

    def model_to_string(model)
      String.new LibZ3.model_to_string(Context, model)
    end
  end
end
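One possible way to cut down that copypasta later is a macro loop that generates the wrappers. This is only a sketch of the idea, not what the repo does: FakeLib is a hypothetical stand-in for LibZ3 whose "C functions" just return strings, so the snippet runs without Z3.

```crystal
# Hypothetical stand-in for LibZ3: every function takes the context first
# and returns a string describing the expression it would build.
module FakeLib
  def self.mk_eq(ctx : String, a : Int32, b : Int32) : String
    "#{ctx}:(= #{a} #{b})"
  end

  def self.mk_ge(ctx : String, a : Int32, b : Int32) : String
    "#{ctx}:(>= #{a} #{b})"
  end

  def self.mk_gt(ctx : String, a : Int32, b : Int32) : String
    "#{ctx}:(> #{a} #{b})"
  end
end

module API
  extend self

  Context = "ctx"

  # Generate one wrapper per binary function; each passes Context implicitly.
  {% for name in ["mk_eq", "mk_ge", "mk_gt"] %}
    def {{name.id}}(a, b)
      FakeLib.{{name.id}}(Context, a, b)
    end
  {% end %}
end

puts API.mk_eq(1, 2) # => ctx:(= 1 2)
puts API.mk_ge(3, 4) # => ctx:(>= 3 4)
```

Functions with a different shape, like the count-plus-array ones (mk_add, mk_mul, mk_distinct), would need a loop of their own.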

z3/solver.cr

Solver now runs the check automatically, caches the model, and raises an exception (no special exception class yet) if you ask for the model and the constraints turn out not to be satisfiable. Asserting anything new clears both cached values.

module Z3
  class Solver
    def initialize
      @solver = API.mk_solver
      @model = nil
      @check = nil
    end

    def assert(expr)
      @model = nil
      @check = nil
      API.solver_assert(@solver, expr._expr)
    end

    def check
      @check = API.solver_check(@solver)
    end

    def model
      @model ||= begin
        check unless @check
        raise "Model not satisfiable" unless @check == LibZ3::LBool::True
        Model.new(API.solver_get_model(@solver))
      end
    end
  end
end
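The caching rules are easy to get subtly wrong, so here's a self-contained sketch of the same pattern that runs without Z3. CachingSolver is hypothetical: a "constraint" is just a Bool, the "model" is a placeholder string, and check_count exists only to show when the cache is hit.

```crystal
# Hypothetical stand-in mirroring Z3::Solver's caching: constraints are Bools,
# and "satisfiable" simply means every constraint is true.
class CachingSolver
  getter check_count = 0

  @constraints = [] of Bool
  @model : String? = nil
  @check : Bool? = nil

  def assert(constraint : Bool)
    # Any new constraint invalidates both cached results.
    @model = nil
    @check = nil
    @constraints << constraint
  end

  def check : Bool
    @check_count += 1
    @check = @constraints.all?
  end

  def model : String
    @model ||= begin
      check if @check.nil? # check lazily, only when needed
      raise "Model not satisfiable" unless @check
      "model for #{@constraints.size} constraint(s)"
    end
  end
end

solver = CachingSolver.new
solver.assert true
solver.model
solver.model
puts solver.check_count # => 1, the second call was served from the cache

solver.assert false
begin
  solver.model
rescue ex
  puts ex.message # => Model not satisfiable
end
puts solver.check_count # => 2
```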

z3/model.cr

This class will need a lot of work next, so we can actually extract data from the model.

module Z3
  class Model
    def initialize(model : LibZ3::Model)
      @model = model
    end

    # This needs to go eventually
    def to_s(io)
      io << API.model_to_string(@model).chomp
    end
  end
end

z3/int_sort.cr

It got a very interesting method, self.[](expr : IntExpr), which just returns its argument unchanged. Combined with the Symbol and Int overloads, this means Z3::IntSort[a] + Z3::IntSort[b] returns a Z3::IntExpr for many different kinds of a and b:

module Z3
  class IntSort
    @@sort = LibZ3.mk_int_sort(API::Context)

    def self.[](expr : IntExpr)
      expr
    end

    def self.[](name : Symbol)
      IntExpr.new API.mk_const(name.to_s, @@sort)
    end

    def self.[](v : Int)
      IntExpr.new API.mk_numeral(v, @@sort)
    end
  end
end
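The trick is plain overloading on the argument's type, and the operators then route their argument through the same method. Here's a stripped-down sketch that runs without Z3; MiniSort and MiniExpr are hypothetical names, and expressions are just S-expression strings.

```crystal
# Hypothetical miniature of IntSort/IntExpr: expressions are just strings.
class MiniExpr
  getter text : String

  def initialize(@text : String)
  end

  # Like IntExpr#+: `other` may be anything MiniSort.[] accepts.
  def +(other)
    MiniExpr.new("(+ #{text} #{MiniSort[other].text})")
  end
end

class MiniSort
  def self.[](expr : MiniExpr)
    expr # already an expression: pass it through unchanged
  end

  def self.[](name : Symbol)
    MiniExpr.new(name.to_s) # a named variable
  end

  def self.[](v : Int)
    MiniExpr.new(v.to_s) # an integer literal
  end
end

sum = MiniSort[:x] + 2 + MiniSort[:y]
puts sum.text # => (+ (+ x 2) y)
```

Anything without a matching overload, say a Float64, is rejected at compile time with a "no overload matches" error rather than failing at runtime.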

z3/bool_sort.cr

The whole BoolSort is a placeholder, but it will need similar treatment to what IntSort got.

module Z3
  class BoolSort
    @@sort = LibZ3.mk_bool_sort(API::Context)
  end
end

z3/int_expr.cr

The big change is that other can now be anything sort[...] accepts, so you can combine an IntExpr with Int32s, BigInts, or whatnot.

You can also pass Symbols. I'm not completely convinced it's a good idea (the Ruby API doesn't do that), but I'll leave it for now. Maybe it turns out to be useful somehow?

A small change is adding - and /. At some point I'll go through the whole list and add all the remaining operators.

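Another small change is marking _expr as protected. protected is this weird concept that apparently almost every OOP language has, but it means something completely different in each of them: C++ protected, Java protected, Ruby protected, and Crystal protected are not even remotely close. In Crystal, a protected method can be called on instances of the same type, or from anywhere in the same namespace, which is why Z3::Solver and Z3.distinct can still call _expr while code outside module Z3 can't.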

module Z3
  class IntExpr
    def initialize(expr : LibZ3::Ast)
      @expr = expr
    end

    def sort
      IntSort
    end

    def ==(other)
      BoolExpr.new API.mk_eq(@expr, sort[other]._expr)
    end

    def >=(other)
      BoolExpr.new API.mk_ge(@expr, sort[other]._expr)
    end

    def >(other)
      BoolExpr.new API.mk_gt(@expr, sort[other]._expr)
    end

    def <=(other)
      BoolExpr.new API.mk_le(@expr, sort[other]._expr)
    end

    def <(other)
      BoolExpr.new API.mk_lt(@expr, sort[other]._expr)
    end

    def *(other)
      IntExpr.new API.mk_mul([@expr, sort[other]._expr])
    end

    def +(other)
      IntExpr.new API.mk_add([@expr, sort[other]._expr])
    end

    def -(other)
      IntExpr.new API.mk_sub([@expr, sort[other]._expr])
    end

    def /(other)
      IntExpr.new API.mk_div(@expr, sort[other]._expr)
    end

    protected def _expr
      @expr
    end
  end
end
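Crystal's namespace-based protected is exactly what this code relies on. Here's a minimal sketch of the rule with hypothetical names (Demo, Expr, Solver, raw), independent of Z3:

```crystal
module Demo
  class Expr
    def initialize(@raw : String)
    end

    protected def raw
      @raw
    end

    # Allowed: calling it on another instance of the same type.
    def same_as?(other : Expr) : Bool
      raw == other.raw
    end
  end

  class Solver
    # Allowed: Solver lives in the same namespace (Demo) as Expr.
    def peek(expr : Expr) : String
      expr.raw
    end
  end
end

expr = Demo::Expr.new("(+ x 2)")
puts Demo::Solver.new.peek(expr)              # => (+ x 2)
puts expr.same_as?(Demo::Expr.new("(+ x 2)")) # => true
# expr.raw # would not compile: protected method called from outside Demo
```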

z3/bool_expr.cr

It's basically a placeholder:

module Z3
  class BoolExpr
    def initialize(expr : LibZ3::Ast)
      @expr = expr
    end

    def sort
      BoolSort
    end

    protected def _expr
      @expr
    end
  end
end

z3/core_ext.cr

I was worried about the Crystal equivalent of Ruby's #coerce, but it went super smoothly: reopening Int and adding overloads that take a Z3::IntExpr is all it takes.

abstract struct Int
  def +(other : Z3::IntExpr)
    Z3::IntSort[self] + other
  end

  def *(other : Z3::IntExpr)
    Z3::IntSort[self] * other
  end

  def /(other : Z3::IntExpr)
    Z3::IntSort[self] / other
  end

  def -(other : Z3::IntExpr)
    Z3::IntSort[self] - other
  end

  def ==(other : Z3::IntExpr)
    Z3::IntSort[self] == other
  end

  def >=(other : Z3::IntExpr)
    Z3::IntSort[self] >= other
  end

  def >(other : Z3::IntExpr)
    Z3::IntSort[self] > other
  end

  def <=(other : Z3::IntExpr)
    Z3::IntSort[self] <= other
  end

  def <(other : Z3::IntExpr)
    Z3::IntSort[self] < other
  end
end
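Here's the same reopening trick in isolation, as a minimal sketch with a hypothetical Expr type standing in for Z3::IntExpr. The new overload only matches when the right-hand side is an Expr, so ordinary Int arithmetic keeps using its existing overloads.

```crystal
# Hypothetical expression type, standing in for Z3::IntExpr.
class Expr
  getter text : String

  def initialize(@text : String)
  end

  def +(other : Int) : Expr
    Expr.new("(+ #{text} #{other})")
  end
end

# Reopen Int, like core_ext.cr does: this overload is only chosen when the
# right-hand side is an Expr.
abstract struct Int
  def +(other : Expr) : Expr
    Expr.new("(+ #{self} #{other.text})")
  end
end

left = 2 + Expr.new("x")
right = Expr.new("x") + 2
puts left.text  # => (+ 2 x)
puts right.text # => (+ x 2)
puts 2 + 3      # => 5
```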

z3.cr

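Main file for our library. I really love wildcard require; it's almost as good as Rails-style autoloading. require "./z3/*" loads every .cr file directly inside z3/ (require "./z3/**" would also recurse into subdirectories).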

require "./z3/*"

module Z3
  def Z3.distinct(args : Array(IntExpr))
    BoolExpr.new API.mk_distinct(args.map(&._expr))
  end
end

send_more_money.cr

I messed with the syntax a bit on purpose, putting some numbers on the left and some on the right, so both the IntExpr operators and the Int extensions get exercised:

#!/usr/bin/env crystal

require "./z3"

# Setup library
solver = Z3::Solver.new

# Variables, all 0 to 9
vars = Hash(Symbol, Z3::IntExpr).new
%i[s e n d m o r e m o n e y].uniq.each do |name|
  var = Z3::IntSort[name]
  vars[name] = var
  solver.assert 0 <= var
  solver.assert 9 >= var
end

# m and s need to be >= 1, no leading zeroes
solver.assert vars[:m] >= 1
solver.assert vars[:s] >= 1

# all letters represent different digits
solver.assert Z3.distinct(vars.values)

# SEND + MORE = MONEY
send_sum = (
  vars[:s] * 1000 +
  vars[:e] * 100 +
  vars[:n] * 10 +
  vars[:d]
)

more_sum = (
  vars[:m] * 1000 +
  vars[:o] * 100 +
  vars[:r] * 10 +
  vars[:e]
)

money_sum = (
  10000 * vars[:m] +
  1000 * vars[:o] +
  100 * vars[:n] +
  10 * vars[:e] +
  vars[:y]
)

solver.assert send_sum + more_sum == money_sum

# Get the result
puts solver.model
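To double-check whatever the solver prints, a brute-force search is cheap here: assigning 10 digits to 8 distinct letters is only 1,814,400 permutations. This is an independent sanity check written for this puzzle, not part of the library.

```crystal
# Independent brute-force check of SEND + MORE = MONEY (no Z3 involved).
solutions = [] of {Int32, Int32, Int32}

(0..9).to_a.each_permutation(8) do |digits|
  # Each permutation assigns distinct digits to the 8 distinct letters.
  s, e, n, d, m, o, r, y = digits
  next if s == 0 || m == 0 # no leading zeroes

  send_sum = s * 1000 + e * 100 + n * 10 + d
  more_sum = m * 1000 + o * 100 + r * 10 + e
  money_sum = m * 10000 + o * 1000 + n * 100 + e * 10 + y
  solutions << {send_sum, more_sum, money_sum} if send_sum + more_sum == money_sum
end

solutions.each do |sums|
  puts "#{sums[0]} + #{sums[1]} = #{sums[2]}"
end
# prints: 9567 + 1085 = 10652
```

The solver's model should agree: s=9, e=5, n=6, d=7, m=1, o=0, r=8, y=2.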

Story so far

All the code is in the crystal-z3 repo.

Everything in this episode went amazingly smoothly. Not even a tiny issue.

Coming next

In the next episode we'll see how to extract data from the model, and maybe we can do a Sudoku solver.