diff --git a/cpp2rust/converter/converter.cpp b/cpp2rust/converter/converter.cpp index 359210aa..8f3b1d7e 100644 --- a/cpp2rust/converter/converter.cpp +++ b/cpp2rust/converter/converter.cpp @@ -990,6 +990,7 @@ bool Converter::VisitIfStmt(clang::IfStmt *stmt) { } bool Converter::VisitWhileStmt(clang::WhileStmt *stmt) { + PushBreakTarget push(break_target_, BreakTarget::Loop); StrCat("'loop_:"); StrCat(keyword::kWhile); ConvertCondition(stmt->getCond()); @@ -1002,6 +1003,7 @@ bool Converter::VisitWhileStmt(clang::WhileStmt *stmt) { } bool Converter::VisitDoStmt(clang::DoStmt *stmt) { + PushBreakTarget push(break_target_, BreakTarget::Loop); StrCat("'loop_:"); StrCat(keyword::kLoop, token::kOpenCurlyBracket); curr_for_inc_.emplace(nullptr); @@ -1016,6 +1018,7 @@ bool Converter::VisitDoStmt(clang::DoStmt *stmt) { } bool Converter::VisitForStmt(clang::ForStmt *stmt) { + PushBreakTarget push(break_target_, BreakTarget::Loop); Convert(stmt->getInit()); StrCat("'loop_:"); StrCat(keyword::kWhile); @@ -1055,6 +1058,7 @@ void Converter::ConvertLoopVariable(clang::VarDecl *decl, void Converter::ConvertForRangeBody(clang::CXXForRangeStmt *stmt, const clang::VarDecl *map_iter_decl) { + PushBreakTarget push(break_target_, BreakTarget::Loop); std::optional skip; if (map_iter_decl) skip.emplace(*this, map_iter_decl); @@ -1137,7 +1141,7 @@ bool Converter::VisitCXXForRangeStmtIndexBased(clang::CXXForRangeStmt *stmt, bool Converter::VisitBreakStmt([[maybe_unused]] clang::BreakStmt *stmt) { StrCat(keyword::kBreak); - if (break_with_explicit_label_) { + if (isSwitchBreak()) { StrCat("'switch"); } return false; @@ -2617,66 +2621,69 @@ bool Converter::VisitImplicitValueInitExpr(clang::ImplicitValueInitExpr *expr) { return false; } -static std::unordered_set visited_cases; +bool Converter::ConvertSwitchCaseCondition(clang::SwitchCase *stmt) { + clang::Stmt *cur = stmt; + clang::SwitchCase *last = nullptr; + bool first = true; -bool Converter::VisitSwitchCase(clang::SwitchCase *stmt) { - if (visited_cases.contains(stmt)) { - return false; + while (auto *sc = clang::dyn_cast(cur)) { + if (auto *case_stmt = clang::dyn_cast(sc)) { + if (!first) { + StrCat("|| v == "); + } + Convert(case_stmt->getLHS()); + } + last = sc; + first = false; + cur = sc->getSubStmt(); } - visited_cases.insert(stmt); - if (auto case_stmt = clang::dyn_cast(stmt)) { - Convert(case_stmt->getLHS()); + if (clang::isa(last)) { + StrCat(" => {"); + } else /* DefaultStmt */ { + StrCat("_ => {"); } + return false; +} - if (clang::isa(stmt->getSubStmt())) { - StrCat("|| v == "); +void Converter::EmitSwitchArm(clang::CompoundStmt *body, clang::SwitchCase *sc, + bool is_default) { + if (is_default) { + StrCat("_ => {"); } else { - if (clang::isa(stmt)) { - StrCat(" => {"); - } else { - StrCat("_ => {"); - } + StrCat("v if v == "); + ConvertSwitchCaseCondition(sc); } - - Convert(stmt->getSubStmt()); - return false; + for (auto *t : GetSwitchCaseBody(body, sc)) { + Convert(t); + } + StrCat("},"); } bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) { + PushBreakTarget push(break_target_, BreakTarget::Switch); + auto *body = clang::dyn_cast(stmt->getBody()); + assert(body); + StrCat("'switch: {"); StrCat(std::format("let __match_cond = {};", ToString(stmt->getCond()))); StrCat("match __match_cond"); StrCat("{"); - bool has_default_case = false; - auto body = llvm::cast(stmt->getBody()); - assert(body); - - break_with_explicit_label_ = true; - for (auto it = body->body_begin(), end = body->body_end(); it != end;) { - if (auto switch_case = clang::dyn_cast(*it)) { - if (clang::isa(switch_case)) { - StrCat("v if v == "); - } else { - has_default_case = true; - } - VisitSwitchCase(switch_case); - ++it; - } - - while (it != end && !clang::isa(*it)) { - Convert(*it); - ++it; + clang::SwitchCase *default_case = nullptr; + for (auto *sc : GetTopLevelSwitchCases(stmt)) { + if (SwitchCaseContainsDefault(sc)) { + default_case = sc; + continue; } - - StrCat("},"); + EmitSwitchArm(body, sc, /*is_default=*/false); } - if (!has_default_case) { + if (default_case) { + EmitSwitchArm(body, default_case, /*is_default=*/true); + } else { StrCat(R"( _ => {})"); } - break_with_explicit_label_ = false; StrCat("}"); StrCat("}"); diff --git a/cpp2rust/converter/converter.h b/cpp2rust/converter/converter.h index 74e8e750..61143d92 100644 --- a/cpp2rust/converter/converter.h +++ b/cpp2rust/converter/converter.h @@ -290,7 +290,10 @@ class Converter : public clang::RecursiveASTVisitor { virtual bool VisitSwitchStmt(clang::SwitchStmt *stmt); - virtual bool VisitSwitchCase(clang::SwitchCase *stmt); + void EmitSwitchArm(clang::CompoundStmt *body, clang::SwitchCase *sc, + bool is_default); + + bool ConvertSwitchCaseCondition(clang::SwitchCase *stmt); virtual bool VisitVAArgExpr(clang::VAArgExpr *expr); @@ -463,10 +466,30 @@ class Converter : public clang::RecursiveASTVisitor { clang::ASTContext &ctx_; clang::FunctionDecl *curr_function_ = nullptr; bool in_function_formals_ = false; - bool break_with_explicit_label_ = false; std::stack curr_for_inc_; std::stack curr_init_type_; + enum class BreakTarget { Loop, Switch }; + std::stack break_target_; + + bool isSwitchBreak() const { + return !break_target_.empty() && break_target_.top() == BreakTarget::Switch; + } + + class PushBreakTarget { + public: + PushBreakTarget(std::stack &stack, BreakTarget target) + : stack_(stack) { + stack_.push(target); + } + ~PushBreakTarget() { stack_.pop(); } + PushBreakTarget(const PushBreakTarget &) = delete; + PushBreakTarget &operator=(const PushBreakTarget &) = delete; + + private: + std::stack &stack_; + }; + std::unordered_set map_iter_decls_; struct ScopedMapIterDecl { diff --git a/cpp2rust/converter/converter_lib.cpp b/cpp2rust/converter/converter_lib.cpp index d493ad74..d784906f 100644 --- a/cpp2rust/converter/converter_lib.cpp +++ b/cpp2rust/converter/converter_lib.cpp @@ -660,4 +660,56 @@ clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx) { /*BasePath=*/nullptr, clang::VK_PRValue, clang::FPOptionsOverride()); } +std::vector +GetTopLevelSwitchCases(clang::SwitchStmt *stmt) { + std::vector cases; + if (auto *body = llvm::dyn_cast(stmt->getBody())) { + for (auto *s : body->body()) { + if (auto *sc = clang::dyn_cast(s)) { + cases.push_back(sc); + } + } + } + return cases; +} + +bool SwitchCaseContainsDefault(clang::SwitchCase *c) { + for (clang::Stmt *cur = c;;) { + if (clang::isa(cur)) { + return true; + } + auto *sc = clang::dyn_cast(cur); + if (!sc) { + return false; + } + cur = sc->getSubStmt(); + } + return false; +} + +static clang::Stmt *GetLastStmtOfSwitchCase(clang::SwitchCase *c) { + clang::Stmt *cur = c->getSubStmt(); + while (auto *sc = clang::dyn_cast(cur)) { + cur = sc->getSubStmt(); + } + return cur; +} + +std::vector GetSwitchCaseBody(clang::CompoundStmt *body, + clang::SwitchCase *head) { + std::vector out; + out.push_back(GetLastStmtOfSwitchCase(head)); + auto it = body->body_begin(), end = body->body_end(); + while (it != end && *it != head) { + ++it; + } + assert(it != end); + ++it; + while (it != end && !clang::isa(*it)) { + out.push_back(*it); + ++it; + } + return out; +} + } // namespace cpp2rust diff --git a/cpp2rust/converter/converter_lib.h b/cpp2rust/converter/converter_lib.h index 36776336..c10ca23a 100644 --- a/cpp2rust/converter/converter_lib.h +++ b/cpp2rust/converter/converter_lib.h @@ -154,4 +154,12 @@ bool ContainsVAArgExpr(const clang::Stmt *stmt); clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx); +std::vector +GetTopLevelSwitchCases(clang::SwitchStmt *stmt); + +bool SwitchCaseContainsDefault(clang::SwitchCase *c); + +std::vector GetSwitchCaseBody(clang::CompoundStmt *body, + clang::SwitchCase *head); + } // namespace cpp2rust diff --git a/tests/unit/out/refcount/switch_default_first.rs b/tests/unit/out/refcount/switch_default_first.rs index 4904dd1c..d3591a5c 100644 --- a/tests/unit/out/refcount/switch_default_first.rs +++ b/tests/unit/out/refcount/switch_default_first.rs @@ -12,10 +12,6 @@ pub fn default_first_0(x: i32) -> i32 { 'switch: { let __match_cond = (*x.borrow()); match __match_cond { - _ => { - (*r.borrow_mut()) = 7; - break 'switch; - } v if v == 1 => { (*r.borrow_mut()) = 1; break 'switch; @@ -24,6 +20,10 @@ pub fn default_first_0(x: i32) -> i32 { (*r.borrow_mut()) = 2; break 'switch; } + _ => { + (*r.borrow_mut()) = 7; + break 'switch; + } } }; return (*r.borrow()); diff --git a/tests/unit/out/refcount/switch_default_middle.rs b/tests/unit/out/refcount/switch_default_middle.rs index e5296161..309a54d8 100644 --- a/tests/unit/out/refcount/switch_default_middle.rs +++ b/tests/unit/out/refcount/switch_default_middle.rs @@ -16,14 +16,14 @@ pub fn default_middle_0(x: i32) -> i32 { (*r.borrow_mut()) = 1; break 'switch; } - _ => { - (*r.borrow_mut()) = 99; - break 'switch; - } v if v == 2 => { (*r.borrow_mut()) = 2; break 'switch; } + _ => { + (*r.borrow_mut()) = 99; + break 'switch; + } } }; return (*r.borrow()); diff --git a/tests/unit/out/refcount/switch_for_in_switch_break.rs b/tests/unit/out/refcount/switch_for_in_switch_break.rs index 689034ca..444b1a17 100644 --- a/tests/unit/out/refcount/switch_for_in_switch_break.rs +++ b/tests/unit/out/refcount/switch_for_in_switch_break.rs @@ -16,7 +16,7 @@ pub fn for_in_switch_break_0(n: i32) -> i32 { let i: Value = Rc::new(RefCell::new(0)); 'loop_: while ((*i.borrow()) < 10) { if ((*i.borrow()) == 3) { - break 'switch; + break; } (*r.borrow_mut()) += (*i.borrow()); (*i.borrow_mut()).prefix_inc(); diff --git a/tests/unit/out/refcount/switch_for_switch_for_break.rs b/tests/unit/out/refcount/switch_for_switch_for_break.rs index 49a797da..8063b20a 100644 --- a/tests/unit/out/refcount/switch_for_switch_for_break.rs +++ b/tests/unit/out/refcount/switch_for_switch_for_break.rs @@ -18,7 +18,7 @@ pub fn for_switch_for_break_0(n: i32) -> i32 { let j: Value = Rc::new(RefCell::new(0)); 'loop_: while ((*j.borrow()) < 10) { if ((*j.borrow()) == 2) { - break 'switch; + break; } (*r.borrow_mut()) += 1; (*j.borrow_mut()).prefix_inc(); diff --git a/tests/unit/out/refcount/switch_nested.rs b/tests/unit/out/refcount/switch_nested.rs index d9ef00c2..3f98290d 100644 --- a/tests/unit/out/refcount/switch_nested.rs +++ b/tests/unit/out/refcount/switch_nested.rs @@ -32,15 +32,15 @@ pub fn nested_0(a: i32, b: i32) -> i32 { } }; (*r.borrow_mut()) += 1; - break; + break 'switch; } v if v == 2 => { (*r.borrow_mut()) = 2; - break; + break 'switch; } _ => { (*r.borrow_mut()) = -1_i32; - break; + break 'switch; } } }; diff --git a/tests/unit/out/refcount/switch_while_in_switch_break.rs b/tests/unit/out/refcount/switch_while_in_switch_break.rs index 0bc32f78..8cbb0698 100644 --- a/tests/unit/out/refcount/switch_while_in_switch_break.rs +++ b/tests/unit/out/refcount/switch_while_in_switch_break.rs @@ -16,7 +16,7 @@ pub fn while_in_switch_break_0(n: i32) -> i32 { let i: Value = Rc::new(RefCell::new(0)); 'loop_: while ((*i.borrow()) < 10) { if ((*i.borrow()) == 4) { - break 'switch; + break; } (*r.borrow_mut()) += (*i.borrow()); (*i.borrow_mut()).prefix_inc(); diff --git a/tests/unit/out/unsafe/switch_default_first.rs b/tests/unit/out/unsafe/switch_default_first.rs index 3f39c7e4..1be0755f 100644 --- a/tests/unit/out/unsafe/switch_default_first.rs +++ b/tests/unit/out/unsafe/switch_default_first.rs @@ -11,10 +11,6 @@ pub unsafe fn default_first_0(mut x: i32) -> i32 { 'switch: { let __match_cond = x; match __match_cond { - _ => { - r = 7; - break 'switch; - } v if v == 1 => { r = 1; break 'switch; @@ -23,6 +19,10 @@ pub unsafe fn default_first_0(mut x: i32) -> i32 { r = 2; break 'switch; } + _ => { + r = 7; + break 'switch; + } } }; return r; diff --git a/tests/unit/out/unsafe/switch_default_middle.rs b/tests/unit/out/unsafe/switch_default_middle.rs index c1e53498..ff44c598 100644 --- a/tests/unit/out/unsafe/switch_default_middle.rs +++ b/tests/unit/out/unsafe/switch_default_middle.rs @@ -15,14 +15,14 @@ pub unsafe fn default_middle_0(mut x: i32) -> i32 { r = 1; break 'switch; } - _ => { - r = 99; - break 'switch; - } v if v == 2 => { r = 2; break 'switch; } + _ => { + r = 99; + break 'switch; + } } }; return r; diff --git a/tests/unit/out/unsafe/switch_for_in_switch_break.rs b/tests/unit/out/unsafe/switch_for_in_switch_break.rs index 391a491c..41ef40a0 100644 --- a/tests/unit/out/unsafe/switch_for_in_switch_break.rs +++ b/tests/unit/out/unsafe/switch_for_in_switch_break.rs @@ -15,7 +15,7 @@ pub unsafe fn for_in_switch_break_0(mut n: i32) -> i32 { let mut i: i32 = 0; 'loop_: while ((i) < (10)) { if ((i) == (3)) { - break 'switch; + break; } r += i; i.prefix_inc(); diff --git a/tests/unit/out/unsafe/switch_for_switch_for_break.rs b/tests/unit/out/unsafe/switch_for_switch_for_break.rs index 4eb72969..65c304b9 100644 --- a/tests/unit/out/unsafe/switch_for_switch_for_break.rs +++ b/tests/unit/out/unsafe/switch_for_switch_for_break.rs @@ -17,7 +17,7 @@ pub unsafe fn for_switch_for_break_0(mut n: i32) -> i32 { let mut j: i32 = 0; 'loop_: while ((j) < (10)) { if ((j) == (2)) { - break 'switch; + break; } r += 1; j.prefix_inc(); diff --git a/tests/unit/out/unsafe/switch_nested.rs b/tests/unit/out/unsafe/switch_nested.rs index aeee5de3..5630db70 100644 --- a/tests/unit/out/unsafe/switch_nested.rs +++ b/tests/unit/out/unsafe/switch_nested.rs @@ -30,15 +30,15 @@ pub unsafe fn nested_0(mut a: i32, mut b: i32) -> i32 { } }; r += 1; - break; + break 'switch; } v if v == 2 => { r = 2; - break; + break 'switch; } _ => { r = -1_i32; - break; + break 'switch; } } }; diff --git a/tests/unit/out/unsafe/switch_while_in_switch_break.rs b/tests/unit/out/unsafe/switch_while_in_switch_break.rs index 42900cb9..2acb971c 100644 --- a/tests/unit/out/unsafe/switch_while_in_switch_break.rs +++ b/tests/unit/out/unsafe/switch_while_in_switch_break.rs @@ -15,7 +15,7 @@ pub unsafe fn while_in_switch_break_0(mut n: i32) -> i32 { let mut i: i32 = 0; 'loop_: while ((i) < (10)) { if ((i) == (4)) { - break 'switch; + break; } r += i; i.prefix_inc(); diff --git a/tests/unit/switch_default_first.cpp b/tests/unit/switch_default_first.cpp index 5ee6e3d6..7dec6d69 100644 --- a/tests/unit/switch_default_first.cpp +++ b/tests/unit/switch_default_first.cpp @@ -1,4 +1,3 @@ -// panic #include int default_first(int x) { diff --git a/tests/unit/switch_default_middle.cpp b/tests/unit/switch_default_middle.cpp index d187b47e..050440d0 100644 --- a/tests/unit/switch_default_middle.cpp +++ b/tests/unit/switch_default_middle.cpp @@ -1,4 +1,3 @@ -// panic #include int default_middle(int x) { diff --git a/tests/unit/switch_for_in_switch_break.cpp b/tests/unit/switch_for_in_switch_break.cpp index d15e0487..e7dda947 100644 --- a/tests/unit/switch_for_in_switch_break.cpp +++ b/tests/unit/switch_for_in_switch_break.cpp @@ -1,4 +1,3 @@ -// panic #include int for_in_switch_break(int n) { diff --git a/tests/unit/switch_for_switch_for_break.cpp b/tests/unit/switch_for_switch_for_break.cpp index 769b41ca..7a304a03 100644 --- a/tests/unit/switch_for_switch_for_break.cpp +++ b/tests/unit/switch_for_switch_for_break.cpp @@ -1,4 +1,3 @@ -// panic #include int for_switch_for_break(int n) { diff --git a/tests/unit/switch_nested.cpp b/tests/unit/switch_nested.cpp index 37f7f0bc..c7473597 100644 --- a/tests/unit/switch_nested.cpp +++ b/tests/unit/switch_nested.cpp @@ -1,4 +1,3 @@ -// no-compile #include int nested(int a, int b) { diff --git a/tests/unit/switch_while_in_switch_break.cpp b/tests/unit/switch_while_in_switch_break.cpp index 419f7ac5..897a7a90 100644 --- a/tests/unit/switch_while_in_switch_break.cpp +++ b/tests/unit/switch_while_in_switch_break.cpp @@ -1,4 +1,3 @@ -// panic #include int while_in_switch_break(int n) {