diff --git a/cpp2rust/converter/converter.cpp b/cpp2rust/converter/converter.cpp index cc8fe39c..da395937 100644 --- a/cpp2rust/converter/converter.cpp +++ b/cpp2rust/converter/converter.cpp @@ -2621,63 +2621,67 @@ bool Converter::VisitImplicitValueInitExpr(clang::ImplicitValueInitExpr *expr) { return false; } -static std::unordered_set visited_cases; - -bool Converter::VisitSwitchCase(clang::SwitchCase *stmt) { - if (visited_cases.contains(stmt)) { - return false; +bool Converter::ConvertSwitchCaseCondition(clang::SwitchCase *stmt) { + clang::Stmt *cur = stmt; + clang::SwitchCase *last = nullptr; + bool first = true; + + while (auto *sc = clang::dyn_cast(cur)) { + if (auto *case_stmt = clang::dyn_cast(sc)) { + if (!first) { + StrCat("|| v == "); + } + Convert(case_stmt->getLHS()); + } + last = sc; + first = false; + cur = sc->getSubStmt(); } - visited_cases.insert(stmt); - if (auto case_stmt = clang::dyn_cast(stmt)) { - Convert(case_stmt->getLHS()); + if (clang::isa(last)) { + StrCat(" => {"); + } else /* DefaultStmt */ { + StrCat("_ => {"); } + return false; +} - if (clang::isa(stmt->getSubStmt())) { - StrCat("|| v == "); +void Converter::EmitSwitchArm(clang::CompoundStmt *body, clang::SwitchCase *sc, + bool is_default) { + if (is_default) { + StrCat("_ => {"); } else { - if (clang::isa(stmt)) { - StrCat(" => {"); - } else { - StrCat("_ => {"); - } + StrCat("v if v == "); + ConvertSwitchCaseCondition(sc); } - - Convert(stmt->getSubStmt()); - return false; + for (auto *t : GetSwitchCaseBody(body, sc)) { + Convert(t); + } + StrCat("},"); } bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) { PushBreakTarget push(break_target_, BreakTarget::Switch); + auto *body = clang::dyn_cast(stmt->getBody()); + assert(body); + StrCat("'switch: {"); StrCat(std::format("let __match_cond = {};", ToString(stmt->getCond()))); StrCat("match __match_cond"); StrCat("{"); - bool has_default_case = false; - auto body = llvm::cast(stmt->getBody()); - assert(body); - - for (auto it = body->body_begin(), end = body->body_end(); it != end;) { - if (auto switch_case = clang::dyn_cast(*it)) { - if (clang::isa(switch_case)) { - StrCat("v if v == "); - } else { - has_default_case = true; - } - VisitSwitchCase(switch_case); - ++it; - } - - while (it != end && !clang::isa(*it)) { - Convert(*it); - ++it; + clang::SwitchCase *default_case = nullptr; + for (auto *sc : GetTopLevelSwitchCases(stmt)) { + if (SwitchCaseContainsDefault(sc)) { + default_case = sc; + continue; } - - StrCat("},"); + EmitSwitchArm(body, sc, /*is_default=*/false); } - if (!has_default_case) { + if (default_case) { + EmitSwitchArm(body, default_case, /*is_default=*/true); + } else { StrCat(R"( _ => {})"); } diff --git a/cpp2rust/converter/converter.h b/cpp2rust/converter/converter.h index a0a14931..debb9cbe 100644 --- a/cpp2rust/converter/converter.h +++ b/cpp2rust/converter/converter.h @@ -290,7 +290,10 @@ class Converter : public clang::RecursiveASTVisitor { virtual bool VisitSwitchStmt(clang::SwitchStmt *stmt); - virtual bool VisitSwitchCase(clang::SwitchCase *stmt); + void EmitSwitchArm(clang::CompoundStmt *body, clang::SwitchCase *sc, + bool is_default); + + bool ConvertSwitchCaseCondition(clang::SwitchCase *stmt); virtual bool VisitVAArgExpr(clang::VAArgExpr *expr); diff --git a/cpp2rust/converter/converter_lib.cpp b/cpp2rust/converter/converter_lib.cpp index a7f12c8d..fe5f22a3 100644 --- a/cpp2rust/converter/converter_lib.cpp +++ b/cpp2rust/converter/converter_lib.cpp @@ -662,6 +662,58 @@ clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx) { /*BasePath=*/nullptr, clang::VK_PRValue, clang::FPOptionsOverride()); } +std::vector +GetTopLevelSwitchCases(clang::SwitchStmt *stmt) { + std::vector cases; + if (auto *body = llvm::dyn_cast(stmt->getBody())) { + for (auto *s : body->body()) { + if (auto *sc = clang::dyn_cast(s)) { + cases.push_back(sc); + } + } + } + return cases; +} + +bool SwitchCaseContainsDefault(clang::SwitchCase *c) { + for (clang::Stmt *cur = c;;) { + if (clang::isa(cur)) { + return true; + } + auto *sc = clang::dyn_cast(cur); + if (!sc) { + return false; + } + cur = sc->getSubStmt(); + } + return false; +} + +static clang::Stmt *GetLastStmtOfSwitchCase(clang::SwitchCase *c) { + clang::Stmt *cur = c->getSubStmt(); + while (auto *sc = clang::dyn_cast(cur)) { + cur = sc->getSubStmt(); + } + return cur; +} + +std::vector GetSwitchCaseBody(clang::CompoundStmt *body, + clang::SwitchCase *head) { + std::vector out; + out.push_back(GetLastStmtOfSwitchCase(head)); + auto it = body->body_begin(), end = body->body_end(); + while (it != end && *it != head) { + ++it; + } + assert(it != end); + ++it; + while (it != end && !clang::isa(*it)) { + out.push_back(*it); + ++it; + } + return out; +} + static std::string_view Trim(std::string_view s) { auto is_space = [](unsigned char c) { return std::isspace(c); }; auto b = std::find_if_not(s.begin(), s.end(), is_space); diff --git a/cpp2rust/converter/converter_lib.h b/cpp2rust/converter/converter_lib.h index a5c8d692..515171e2 100644 --- a/cpp2rust/converter/converter_lib.h +++ b/cpp2rust/converter/converter_lib.h @@ -155,6 +155,14 @@ bool ContainsVAArgExpr(const clang::Stmt *stmt); clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx); +std::vector +GetTopLevelSwitchCases(clang::SwitchStmt *stmt); + +bool SwitchCaseContainsDefault(clang::SwitchCase *c); + +std::vector GetSwitchCaseBody(clang::CompoundStmt *body, + clang::SwitchCase *head); + void Unwrap(std::string &s, std::string_view prefix, std::string_view suffix); } // namespace cpp2rust diff --git a/tests/unit/out/refcount/switch_case_then_default.rs b/tests/unit/out/refcount/switch_case_then_default.rs new file mode 100644 index 00000000..6c58a070 --- /dev/null +++ b/tests/unit/out/refcount/switch_case_then_default.rs @@ -0,0 +1,50 @@ +extern crate libcc2rs; +use libcc2rs::*; +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::io::prelude::*; +use std::io::{Read, Seek, Write}; +use std::os::fd::AsFd; +use std::rc::{Rc, Weak}; +pub fn case_then_default_0(x: i32) -> i32 { + let x: Value = Rc::new(RefCell::new(x)); + let r: Value = Rc::new(RefCell::new(0)); + 'switch: { + let __match_cond = (*x.borrow()); + match __match_cond { + v if v == 2 => { + (*r.borrow_mut()) = 20; + break 'switch; + } + _ => { + (*r.borrow_mut()) = 10; + break 'switch; + } + } + }; + return (*r.borrow()); +} +pub fn main() { + std::process::exit(main_0()); +} +fn main_0() -> i32 { + assert!( + (({ + let _x: i32 = 1; + case_then_default_0(_x) + }) == 10) + ); + assert!( + (({ + let _x: i32 = 2; + case_then_default_0(_x) + }) == 20) + ); + assert!( + (({ + let _x: i32 = 99; + case_then_default_0(_x) + }) == 10) + ); + return 0; +} diff --git a/tests/unit/out/refcount/switch_cases_and_default_stacked.rs b/tests/unit/out/refcount/switch_cases_and_default_stacked.rs new file mode 100644 index 00000000..b7ca14b1 --- /dev/null +++ b/tests/unit/out/refcount/switch_cases_and_default_stacked.rs @@ -0,0 +1,56 @@ +extern crate libcc2rs; +use libcc2rs::*; +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::io::prelude::*; +use std::io::{Read, Seek, Write}; +use std::os::fd::AsFd; +use std::rc::{Rc, Weak}; +pub fn cases_and_default_stacked_0(x: i32) -> i32 { + let x: Value = Rc::new(RefCell::new(x)); + let r: Value = Rc::new(RefCell::new(0)); + 'switch: { + let __match_cond = (*x.borrow()); + match __match_cond { + v if v == 3 => { + (*r.borrow_mut()) = 3; + break 'switch; + } + _ => { + (*r.borrow_mut()) = 42; + break 'switch; + } + } + }; + return (*r.borrow()); +} +pub fn main() { + std::process::exit(main_0()); +} +fn main_0() -> i32 { + assert!( + (({ + let _x: i32 = 1; + cases_and_default_stacked_0(_x) + }) == 42) + ); + assert!( + (({ + let _x: i32 = 2; + cases_and_default_stacked_0(_x) + }) == 42) + ); + assert!( + (({ + let _x: i32 = 3; + cases_and_default_stacked_0(_x) + }) == 3) + ); + assert!( + (({ + let _x: i32 = 99; + cases_and_default_stacked_0(_x) + }) == 42) + ); + return 0; +} diff --git a/tests/unit/out/refcount/switch_default_first.rs b/tests/unit/out/refcount/switch_default_first.rs index 4904dd1c..d3591a5c 100644 --- a/tests/unit/out/refcount/switch_default_first.rs +++ b/tests/unit/out/refcount/switch_default_first.rs @@ -12,10 +12,6 @@ pub fn default_first_0(x: i32) -> i32 { 'switch: { let __match_cond = (*x.borrow()); match __match_cond { - _ => { - (*r.borrow_mut()) = 7; - break 'switch; - } v if v == 1 => { (*r.borrow_mut()) = 1; break 'switch; @@ -24,6 +20,10 @@ pub fn default_first_0(x: i32) -> i32 { (*r.borrow_mut()) = 2; break 'switch; } + _ => { + (*r.borrow_mut()) = 7; + break 'switch; + } } }; return (*r.borrow()); diff --git a/tests/unit/out/refcount/switch_default_middle.rs b/tests/unit/out/refcount/switch_default_middle.rs index e5296161..309a54d8 100644 --- a/tests/unit/out/refcount/switch_default_middle.rs +++ b/tests/unit/out/refcount/switch_default_middle.rs @@ -16,14 +16,14 @@ pub fn default_middle_0(x: i32) -> i32 { (*r.borrow_mut()) = 1; break 'switch; } - _ => { - (*r.borrow_mut()) = 99; - break 'switch; - } v if v == 2 => { (*r.borrow_mut()) = 2; break 'switch; } + _ => { + (*r.borrow_mut()) = 99; + break 'switch; + } } }; return (*r.borrow()); diff --git a/tests/unit/out/refcount/switch_default_then_case.rs b/tests/unit/out/refcount/switch_default_then_case.rs new file mode 100644 index 00000000..abaacb07 --- /dev/null +++ b/tests/unit/out/refcount/switch_default_then_case.rs @@ -0,0 +1,60 @@ +extern crate libcc2rs; +use libcc2rs::*; +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::io::prelude::*; +use std::io::{Read, Seek, Write}; +use std::os::fd::AsFd; +use std::rc::{Rc, Weak}; +pub fn default_then_case_0(x: i32) -> i32 { + let x: Value = Rc::new(RefCell::new(x)); + let r: Value = Rc::new(RefCell::new(0)); + 'switch: { + let __match_cond = (*x.borrow()); + match __match_cond { + v if v == 1 => { + (*r.borrow_mut()) = 1; + break 'switch; + } + v if v == 3 => { + (*r.borrow_mut()) = 3; + break 'switch; + } + _ => { + (*r.borrow_mut()) = 77; + break 'switch; + } + } + }; + return (*r.borrow()); +} +pub fn main() { + std::process::exit(main_0()); +} +fn main_0() -> i32 { + assert!( + (({ + let _x: i32 = 1; + default_then_case_0(_x) + }) == 1) + ); + assert!( + (({ + let _x: i32 = 2; + default_then_case_0(_x) + }) == 77) + ); + assert!( + (({ + let _x: i32 = 3; + default_then_case_0(_x) + }) == 3) + ); + assert!( + (({ + let _x: i32 = 99; + default_then_case_0(_x) + }) == 77) + ); + return 0; +} diff --git a/tests/unit/out/unsafe/switch_case_then_default.rs b/tests/unit/out/unsafe/switch_case_then_default.rs new file mode 100644 index 00000000..ee8d944c --- /dev/null +++ b/tests/unit/out/unsafe/switch_case_then_default.rs @@ -0,0 +1,51 @@ +extern crate libc; +use libc::*; +extern crate libcc2rs; +use libcc2rs::*; +use std::collections::BTreeMap; +use std::io::{Read, Seek, Write}; +use std::os::fd::{AsFd, FromRawFd, IntoRawFd}; +use std::rc::Rc; +pub unsafe fn case_then_default_0(mut x: i32) -> i32 { + let mut r: i32 = 0; + 'switch: { + let __match_cond = x; + match __match_cond { + v if v == 2 => { + r = 20; + break 'switch; + } + _ => { + r = 10; + break 'switch; + } + } + }; + return r; +} +pub fn main() { + unsafe { + std::process::exit(main_0() as i32); + } +} +unsafe fn main_0() -> i32 { + assert!( + ((unsafe { + let _x: i32 = 1; + case_then_default_0(_x) + }) == (10)) + ); + assert!( + ((unsafe { + let _x: i32 = 2; + case_then_default_0(_x) + }) == (20)) + ); + assert!( + ((unsafe { + let _x: i32 = 99; + case_then_default_0(_x) + }) == (10)) + ); + return 0; +} diff --git a/tests/unit/out/unsafe/switch_cases_and_default_stacked.rs b/tests/unit/out/unsafe/switch_cases_and_default_stacked.rs new file mode 100644 index 00000000..1541d1e4 --- /dev/null +++ b/tests/unit/out/unsafe/switch_cases_and_default_stacked.rs @@ -0,0 +1,57 @@ +extern crate libc; +use libc::*; +extern crate libcc2rs; +use libcc2rs::*; +use std::collections::BTreeMap; +use std::io::{Read, Seek, Write}; +use std::os::fd::{AsFd, FromRawFd, IntoRawFd}; +use std::rc::Rc; +pub unsafe fn cases_and_default_stacked_0(mut x: i32) -> i32 { + let mut r: i32 = 0; + 'switch: { + let __match_cond = x; + match __match_cond { + v if v == 3 => { + r = 3; + break 'switch; + } + _ => { + r = 42; + break 'switch; + } + } + }; + return r; +} +pub fn main() { + unsafe { + std::process::exit(main_0() as i32); + } +} +unsafe fn main_0() -> i32 { + assert!( + ((unsafe { + let _x: i32 = 1; + cases_and_default_stacked_0(_x) + }) == (42)) + ); + assert!( + ((unsafe { + let _x: i32 = 2; + cases_and_default_stacked_0(_x) + }) == (42)) + ); + assert!( + ((unsafe { + let _x: i32 = 3; + cases_and_default_stacked_0(_x) + }) == (3)) + ); + assert!( + ((unsafe { + let _x: i32 = 99; + cases_and_default_stacked_0(_x) + }) == (42)) + ); + return 0; +} diff --git a/tests/unit/out/unsafe/switch_default_first.rs b/tests/unit/out/unsafe/switch_default_first.rs index 3f39c7e4..1be0755f 100644 --- a/tests/unit/out/unsafe/switch_default_first.rs +++ b/tests/unit/out/unsafe/switch_default_first.rs @@ -11,10 +11,6 @@ pub unsafe fn default_first_0(mut x: i32) -> i32 { 'switch: { let __match_cond = x; match __match_cond { - _ => { - r = 7; - break 'switch; - } v if v == 1 => { r = 1; break 'switch; @@ -23,6 +19,10 @@ pub unsafe fn default_first_0(mut x: i32) -> i32 { r = 2; break 'switch; } + _ => { + r = 7; + break 'switch; + } } }; return r; diff --git a/tests/unit/out/unsafe/switch_default_middle.rs b/tests/unit/out/unsafe/switch_default_middle.rs index c1e53498..ff44c598 100644 --- a/tests/unit/out/unsafe/switch_default_middle.rs +++ b/tests/unit/out/unsafe/switch_default_middle.rs @@ -15,14 +15,14 @@ pub unsafe fn default_middle_0(mut x: i32) -> i32 { r = 1; break 'switch; } - _ => { - r = 99; - break 'switch; - } v if v == 2 => { r = 2; break 'switch; } + _ => { + r = 99; + break 'switch; + } } }; return r; diff --git a/tests/unit/out/unsafe/switch_default_then_case.rs b/tests/unit/out/unsafe/switch_default_then_case.rs new file mode 100644 index 00000000..f75dfb60 --- /dev/null +++ b/tests/unit/out/unsafe/switch_default_then_case.rs @@ -0,0 +1,61 @@ +extern crate libc; +use libc::*; +extern crate libcc2rs; +use libcc2rs::*; +use std::collections::BTreeMap; +use std::io::{Read, Seek, Write}; +use std::os::fd::{AsFd, FromRawFd, IntoRawFd}; +use std::rc::Rc; +pub unsafe fn default_then_case_0(mut x: i32) -> i32 { + let mut r: i32 = 0; + 'switch: { + let __match_cond = x; + match __match_cond { + v if v == 1 => { + r = 1; + break 'switch; + } + v if v == 3 => { + r = 3; + break 'switch; + } + _ => { + r = 77; + break 'switch; + } + } + }; + return r; +} +pub fn main() { + unsafe { + std::process::exit(main_0() as i32); + } +} +unsafe fn main_0() -> i32 { + assert!( + ((unsafe { + let _x: i32 = 1; + default_then_case_0(_x) + }) == (1)) + ); + assert!( + ((unsafe { + let _x: i32 = 2; + default_then_case_0(_x) + }) == (77)) + ); + assert!( + ((unsafe { + let _x: i32 = 3; + default_then_case_0(_x) + }) == (3)) + ); + assert!( + ((unsafe { + let _x: i32 = 99; + default_then_case_0(_x) + }) == (77)) + ); + return 0; +} diff --git a/tests/unit/switch_case_then_default.cpp b/tests/unit/switch_case_then_default.cpp index d41ea62c..49d1d2ca 100644 --- a/tests/unit/switch_case_then_default.cpp +++ b/tests/unit/switch_case_then_default.cpp @@ -1,4 +1,3 @@ -// translation-fail #include int case_then_default(int x) { diff --git a/tests/unit/switch_cases_and_default_stacked.cpp b/tests/unit/switch_cases_and_default_stacked.cpp index b1db9886..57f422ed 100644 --- a/tests/unit/switch_cases_and_default_stacked.cpp +++ b/tests/unit/switch_cases_and_default_stacked.cpp @@ -1,4 +1,3 @@ -// translation-fail #include int cases_and_default_stacked(int x) { diff --git a/tests/unit/switch_default_first.cpp b/tests/unit/switch_default_first.cpp index 5ee6e3d6..7dec6d69 100644 --- a/tests/unit/switch_default_first.cpp +++ b/tests/unit/switch_default_first.cpp @@ -1,4 +1,3 @@ -// panic #include int default_first(int x) { diff --git a/tests/unit/switch_default_middle.cpp b/tests/unit/switch_default_middle.cpp index d187b47e..050440d0 100644 --- a/tests/unit/switch_default_middle.cpp +++ b/tests/unit/switch_default_middle.cpp @@ -1,4 +1,3 @@ -// panic #include int default_middle(int x) { diff --git a/tests/unit/switch_default_then_case.cpp b/tests/unit/switch_default_then_case.cpp index 0b798ef5..eb39ef40 100644 --- a/tests/unit/switch_default_then_case.cpp +++ b/tests/unit/switch_default_then_case.cpp @@ -1,4 +1,3 @@ -// translation-fail #include int default_then_case(int x) {