import scala.collection.SortedSet

/** An integer variable `name` whose domain is the inclusive range `lb..ub`. */
case class Var(name: String, lb: Int, ub: Int) {
  /** Number of values in the inclusive domain `lb..ub`. */
  def size: Int = ub - lb + 1
  override def toString: String = name
}

/**
 * A propositional atom. By default an atom is neither known-false nor known-true;
 * subclasses override `isFalse` / `isTrue` when the atom is a constant.
 */
abstract class Bool {
  def isFalse: Boolean = false
  def isTrue: Boolean = false
}

/** The constant-false atom. */
object False extends Bool {
  override def isFalse: Boolean = true
  override def toString: String = "False"
}

/** The constant-true atom. */
object True extends Bool {
  override def isTrue: Boolean = true
  override def toString: String = "True"
}

/**
 * The order-encoding atom `x <= b`.
 * It is constant false when `b < x.lb` and constant true when `b >= x.ub`.
 */
case class P(x: Var, b: Int) extends Bool {
  override def isFalse: Boolean = b < x.lb
  override def isTrue: Boolean = b >= x.ub
}

/** An atom `p`, negated when `neg` is true (printed with a leading "-"). */
case class Literal(p: Bool, neg: Boolean = false) {
  def isFalse: Boolean = if (neg) p.isTrue else p.isFalse
  def isTrue: Boolean = if (neg) p.isFalse else p.isTrue
  override def toString: String = if (neg) "-" + p else p.toString
}

/** A disjunction of literals; `Clause()` (no literals) is the unsatisfiable empty clause. */
case class Clause(lits: Seq[Literal] = Seq.empty) {
  /** Prepends `lit`; as a right-associative operator it is written `lit +: clause`. */
  def +:(lit: Literal): Clause = Clause(lit +: lits)
  override def toString: String = lits.mkString("{", ", ", "}")
}

/**
 * Order encoding of linear constraints `sum(a_i * x_i) <= c` into clauses over
 * atoms `x <= b` (see [[P]]), together with a few printing examples.
 */
object OrderEncoding1 {
  /** Smallest value of `a * x` over the domain of `x`. */
  def lb(a: Int, x: Var): Int = if (a > 0) a * x.lb else a * x.ub

  /** Largest value of `a * x` over the domain of `x`. */
  def ub(a: Int, x: Var): Int = if (a > 0) a * x.ub else a * x.lb

  /** Smallest value of the weighted sum (0 for an empty sum). */
  def lb(wsum: Seq[(Int, Var)]): Int = wsum.map { case (a, x) => lb(a, x) }.sum

  /** Largest value of the weighted sum (0 for an empty sum). */
  def ub(wsum: Seq[(Int, Var)]): Int = wsum.map { case (a, x) => ub(a, x) }.sum

  /**
   * `floor(b / a)` in integer arithmetic (math.floor on doubles is slow).
   * Uses the remainder instead of pre-adjusting `b`, so it cannot overflow
   * (apart from the inherent `Int.MinValue / -1` case).
   * @throws ArithmeticException if `a == 0`
   */
  def floorDiv(b: Int, a: Int): Int = {
    val q = b / a
    // `/` truncates toward zero; step down when the exact quotient is negative and inexact.
    if (b % a != 0 && (b ^ a) < 0) q - 1 else q
  }

  /**
   * `ceil(b / a)` in integer arithmetic (math.ceil on doubles is slow).
   * Overflow-free for the same reason as [[floorDiv]].
   * @throws ArithmeticException if `a == 0`
   */
  def ceilDiv(b: Int, a: Int): Int = {
    val q = b / a
    // `/` truncates toward zero; step up when the exact quotient is positive and inexact.
    if (b % a != 0 && (b ^ a) >= 0) q + 1 else q
  }

  /** The atom `x <= b`, folded to [[False]] / [[True]] when the domain of `x` decides it. */
  def p(x: Var, b: Int): Bool =
    if (b < x.lb) False else if (b >= x.ub) True else P(x, b)

  /**
   * The literal equivalent to `a * x <= b`:
   *  - `a > 0`: `x <= floor(b / a)`
   *  - `a < 0`: `not (x <= ceil(b / a) - 1)`, i.e. `x >= ceil(b / a)`
   *  - `a == 0`: the constant `0 <= b`
   */
  def le(a: Int, x: Var, b: Int): Literal =
    if (a > 0) Literal(p(x, floorDiv(b, a)), neg = false)
    else if (a < 0) Literal(p(x, ceilDiv(b, a) - 1), neg = true)
    else Literal(if (0 <= b) True else False)

  /**
   * Encodes `sum(a_i * x_i) <= c` as a sequence of clauses (a CNF).
   * `Seq()` means the constraint always holds; `Seq(Clause())` means it never holds.
   *
   * For the head term `a * x`, each relevant value `b` of `x` yields clauses of the form
   * `(x <= b - 1) or (tail <= c - a*b)` when `a > 0`, and
   * `not (x <= b) or (tail <= c - a*b)` when `a < 0`.
   */
  def encodeLe(wsum: Seq[(Int, Var)], c: Int): Seq[Clause] = {
    if (lb(wsum) > c) {
      // Even the smallest value of the sum exceeds c. Recursive calls can reach this
      // with c far below the bound, where the loops below would produce an empty range
      // (read as "always true") instead of the unsatisfiable empty clause.
      Seq(Clause())
    } else if (wsum.isEmpty) {
      // Empty sum is 0, and 0 <= c holds here.
      Seq()
    } else {
      val (a, x) = wsum.head
      val rest = wsum.tail
      if (rest.isEmpty) {
        val lit = le(a, x, c)
        if (lit.isFalse) Seq(Clause())
        else if (lit.isTrue) Seq()
        else Seq(Clause(Seq(lit)))
      } else if (a > 0) {
        val l = x.lb
        // Values of x above u force the tail below its lower bound.
        val u = math.min(x.ub, floorDiv(c - lb(rest), a))
        for {
          b <- l to u + 1
          lit = Literal(p(x, b - 1), neg = false)
          if !lit.isTrue
          clause <- encodeLe(rest, c - a * b)
        } yield if (lit.isFalse) clause else lit +: clause
      } else if (a < 0) {
        // Values of x below l force the tail below its lower bound.
        val l = math.max(x.lb, ceilDiv(lb(rest) - c, -a))
        val u = x.ub
        for {
          b <- l - 1 to u
          lit = Literal(p(x, b), neg = true)
          if !lit.isTrue
          clause <- encodeLe(rest, c - a * b)
        } yield if (lit.isFalse) clause else lit +: clause
      } else {
        // A zero coefficient contributes nothing to the sum.
        encodeLe(rest, c)
      }
    }
  }

  /** Prints the clauses for `x + y <= 7` with `x, y` in `2..6`, then the clause count. */
  def example1: Unit = {
    val x = Var("x", 2, 6)
    val y = Var("y", 2, 6)
    val wsum = Seq((1, x), (1, y))
    val clauses = encodeLe(wsum, 7)
    clauses.foreach(println)
    println(s"${clauses.size} clauses")
  }

  /** Prints the clauses for `x + y - z <= -2` with all variables in `0..3`, then the count. */
  def example2: Unit = {
    val x = Var("x", 0, 3)
    val y = Var("y", 0, 3)
    val z = Var("z", 0, 3)
    val wsum = Seq((1, x), (1, y), (-1, z))
    val clauses = encodeLe(wsum, -2)
    clauses.foreach(println)
    println(s"${clauses.size} clauses")
  }

  /** Prints the clauses for `3x + 5y <= 14` with `x` in `0..5`, `y` in `0..3`, then the count. */
  def example3: Unit = {
    val x = Var("x", 0, 5)
    val y = Var("y", 0, 3)
    val wsum = Seq((3, x), (5, y))
    val clauses = encodeLe(wsum, 14)
    clauses.foreach(println)
    println(s"${clauses.size} clauses")
  }

  /**
   * Same constraint as [[example3]], but the terms are first sorted by ascending domain size,
   * with ties broken by larger absolute coefficient first.
   */
  def example4: Unit = {
    val x = Var("x", 0, 5)
    val y = Var("y", 0, 3)
    val wsum = Seq((3, x), (5, y)).sortWith { case ((a1, v1), (a2, v2)) =>
      v1.size < v2.size || (v1.size == v2.size && math.abs(a1) > math.abs(a2))
    }
    val clauses = encodeLe(wsum, 14)
    clauses.foreach(println)
    println(s"${clauses.size} clauses")
  }

  /** Prints only the clause count for `w + x + y + z <= 200` with all variables in `0..99`. */
  def example5: Unit = {
    val w = Var("w", 0, 99)
    val x = Var("x", 0, 99)
    val y = Var("y", 0, 99)
    val z = Var("z", 0, 99)
    val clauses = encodeLe(Seq((1, w), (1, x), (1, y), (1, z)), 200)
    println(s"${clauses.size} clauses")
  }

  /**
   * Same constraint as [[example5]], split with an auxiliary `u`: `w + x + u <= 200`, plus
   * `y + z - u <= 0` and `-y - z + u <= 0` (together `u = y + z`). Prints the total clause count.
   */
  def example6: Unit = {
    val w = Var("w", 0, 99)
    val x = Var("x", 0, 99)
    val y = Var("y", 0, 99)
    val z = Var("z", 0, 99)
    val u = Var("u", 0, 198)
    val clauses1 = encodeLe(Seq((1, w), (1, x), (1, u)), 200)
    val clauses2 = encodeLe(Seq((1, y), (1, z), (-1, u)), 0)
    val clauses3 = encodeLe(Seq((-1, y), (-1, z), (1, u)), 0)
    val clauses = clauses1 ++ clauses2 ++ clauses3
    println(s"${clauses.size} clauses")
  }
}