diff --git a/cpp2rust/converter/converter.cpp b/cpp2rust/converter/converter.cpp index 359210aa..14be77f5 100644 --- a/cpp2rust/converter/converter.cpp +++ b/cpp2rust/converter/converter.cpp @@ -990,6 +990,7 @@ bool Converter::VisitIfStmt(clang::IfStmt *stmt) { } bool Converter::VisitWhileStmt(clang::WhileStmt *stmt) { + PushBreakTarget push(break_target_stack_, BreakTarget::Loop); StrCat("'loop_:"); StrCat(keyword::kWhile); ConvertCondition(stmt->getCond()); @@ -1002,6 +1003,7 @@ bool Converter::VisitWhileStmt(clang::WhileStmt *stmt) { } bool Converter::VisitDoStmt(clang::DoStmt *stmt) { + PushBreakTarget push(break_target_stack_, BreakTarget::Loop); StrCat("'loop_:"); StrCat(keyword::kLoop, token::kOpenCurlyBracket); curr_for_inc_.emplace(nullptr); @@ -1016,6 +1018,7 @@ bool Converter::VisitDoStmt(clang::DoStmt *stmt) { } bool Converter::VisitForStmt(clang::ForStmt *stmt) { + PushBreakTarget push(break_target_stack_, BreakTarget::Loop); Convert(stmt->getInit()); StrCat("'loop_:"); StrCat(keyword::kWhile); @@ -1055,6 +1058,7 @@ void Converter::ConvertLoopVariable(clang::VarDecl *decl, void Converter::ConvertForRangeBody(clang::CXXForRangeStmt *stmt, const clang::VarDecl *map_iter_decl) { + PushBreakTarget push(break_target_stack_, BreakTarget::Loop); std::optional skip; if (map_iter_decl) skip.emplace(*this, map_iter_decl); @@ -1136,10 +1140,12 @@ bool Converter::VisitCXXForRangeStmtIndexBased(clang::CXXForRangeStmt *stmt, } bool Converter::VisitBreakStmt([[maybe_unused]] clang::BreakStmt *stmt) { - StrCat(keyword::kBreak); - if (break_with_explicit_label_) { + if (break_target_stack_.isRegularSwitch()) { + StrCat(keyword::kBreak); StrCat("'switch"); + return false; } + StrCat(keyword::kBreak); return false; } @@ -2617,69 +2623,80 @@ bool Converter::VisitImplicitValueInitExpr(clang::ImplicitValueInitExpr *expr) { return false; } -static std::unordered_set visited_cases; - bool Converter::VisitSwitchCase(clang::SwitchCase *stmt) { - if (visited_cases.contains(stmt)) { - return false; - } - visited_cases.insert(stmt); - - if (auto case_stmt = clang::dyn_cast(stmt)) { - Convert(case_stmt->getLHS()); - } - - if (clang::isa(stmt->getSubStmt())) { - StrCat("|| v == "); - } else { - if (clang::isa(stmt)) { - StrCat(" => {"); - } else { - StrCat("_ => {"); + clang::Stmt *cur = stmt; + clang::SwitchCase *last = nullptr; + bool first = true; + + while (auto *sc = clang::dyn_cast(cur)) { + if (auto *case_stmt = clang::dyn_cast(sc)) { + if (!first) { + StrCat("|| v == "); + } + Convert(case_stmt->getLHS()); } + last = sc; + first = false; + cur = sc->getSubStmt(); } - Convert(stmt->getSubStmt()); + if (clang::isa(last)) { + StrCat(" => {"); + } else /* DefaultStmt */ { + StrCat("_ => {"); + } return false; } bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) { - StrCat("'switch: {"); - StrCat(std::format("let __match_cond = {};", ToString(stmt->getCond()))); - StrCat("match __match_cond"); - StrCat("{"); - - bool has_default_case = false; - auto body = llvm::cast(stmt->getBody()); + auto *body = clang::dyn_cast(stmt->getBody()); assert(body); - break_with_explicit_label_ = true; - for (auto it = body->body_begin(), end = body->body_end(); it != end;) { - if (auto switch_case = clang::dyn_cast(*it)) { - if (clang::isa(switch_case)) { - StrCat("v if v == "); - } else { - has_default_case = true; - } - VisitSwitchCase(switch_case); - ++it; - } + bool has_fallthrough = SwitchHasFallthrough(stmt); - while (it != end && !clang::isa(*it)) { - Convert(*it); - ++it; - } + if (has_fallthrough) { + StrCat("switch!(match ", ToString(stmt->getCond()), " {"); + } else { + StrCat("'switch: {"); + StrCat(std::format("let __match_cond = {};", ToString(stmt->getCond()))); + StrCat("match __match_cond"); + StrCat("{"); + } + + PushBreakTarget push(break_target_stack_, has_fallthrough + ? BreakTarget::FallthroughSwitch + : BreakTarget::RegularSwitch); + clang::SwitchCase *default_case = nullptr; + for (auto *sc : GetTopLevelSwitchCases(stmt)) { + if (SwitchCaseContainsDefault(sc)) { + default_case = sc; + continue; + } + StrCat("v if v == "); + VisitSwitchCase(sc); + for (auto *t : GetSwitchCaseBody(body, sc)) { + Convert(t); + } StrCat("},"); } - if (!has_default_case) { + if (default_case) { + StrCat("_ => {"); + for (auto *t : GetSwitchCaseBody(body, default_case)) { + Convert(t); + } + StrCat("},"); + } else { StrCat(R"( _ => {})"); } - break_with_explicit_label_ = false; - StrCat("}"); - StrCat("}"); + if (has_fallthrough) { + StrCat("})"); + } else { + StrCat("}"); + StrCat("}"); + } return false; } diff --git a/cpp2rust/converter/converter.h b/cpp2rust/converter/converter.h index 74e8e750..9e3a31c2 100644 --- a/cpp2rust/converter/converter.h +++ b/cpp2rust/converter/converter.h @@ -463,7 +463,35 @@ class Converter : public clang::RecursiveASTVisitor { clang::ASTContext &ctx_; clang::FunctionDecl *curr_function_ = nullptr; bool in_function_formals_ = false; - bool break_with_explicit_label_ = false; + + enum class BreakTarget { Loop, RegularSwitch, FallthroughSwitch }; + class BreakTargetStack { + public: + void push(BreakTarget t) { stack_.push(t); } + void pop() { stack_.pop(); } + bool isRegularSwitch() const { + return !stack_.empty() && stack_.top() == BreakTarget::RegularSwitch; + } + + private: + std::stack stack_; + }; + BreakTargetStack break_target_stack_; + + class PushBreakTarget { + public: + PushBreakTarget(BreakTargetStack &stack, BreakTarget target) + : stack_(stack) { + stack_.push(target); + } + ~PushBreakTarget() { stack_.pop(); } + PushBreakTarget(const PushBreakTarget &) = delete; + PushBreakTarget &operator=(const PushBreakTarget &) = delete; + + private: + BreakTargetStack &stack_; + }; + std::stack curr_for_inc_; std::stack curr_init_type_; diff --git a/cpp2rust/converter/converter_lib.cpp b/cpp2rust/converter/converter_lib.cpp index d493ad74..efcc0c4a 100644 --- a/cpp2rust/converter/converter_lib.cpp +++ b/cpp2rust/converter/converter_lib.cpp @@ -660,4 +660,78 @@ clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx) { /*BasePath=*/nullptr, clang::VK_PRValue, clang::FPOptionsOverride()); } +std::vector +GetTopLevelSwitchCases(clang::SwitchStmt *stmt) { + std::vector cases; + if (auto *body = llvm::dyn_cast(stmt->getBody())) { + for (auto *s : body->body()) { + if (auto *sc = clang::dyn_cast(s)) { + cases.push_back(sc); + } + } + } + return cases; +} + +bool SwitchCaseContainsDefault(clang::SwitchCase *c) { + for (clang::Stmt *cur = c;;) { + if (clang::isa(cur)) { + return true; + } + auto *sc = clang::dyn_cast(cur); + if (!sc) { + return false; + } + cur = sc->getSubStmt(); + } + return false; +} + +static clang::Stmt *GetLastStmtOfSwitchCase(clang::SwitchCase *c) { + clang::Stmt *cur = c->getSubStmt(); + while (auto *sc = clang::dyn_cast(cur)) { + cur = sc->getSubStmt(); + } + return cur; +} + +std::vector GetSwitchCaseBody(clang::CompoundStmt *body, + clang::SwitchCase *head) { + std::vector out; + out.push_back(GetLastStmtOfSwitchCase(head)); + auto it = body->body_begin(), end = body->body_end(); + while (it != end && *it != head) { + ++it; + } + assert(it != end); + ++it; + while (it != end && !clang::isa(*it)) { + out.push_back(*it); + ++it; + } + return out; +} + +static bool SwitchCaseHasFallthrough(clang::Stmt *stmt) { + if (stmt) { + if (clang::isa(stmt) || + clang::isa(stmt)) { + return false; + } + } + return true; +} + +bool SwitchHasFallthrough(clang::SwitchStmt *stmt) { + if (auto *body = clang::dyn_cast(stmt->getBody())) { + for (auto top_level_case : GetTopLevelSwitchCases(stmt)) { + auto arm = GetSwitchCaseBody(body, top_level_case); + if (arm.empty() || SwitchCaseHasFallthrough(arm.back())) { + return true; + } + } + } + return false; +} + } // namespace cpp2rust diff --git a/cpp2rust/converter/converter_lib.h b/cpp2rust/converter/converter_lib.h index 36776336..91a3d44a 100644 --- a/cpp2rust/converter/converter_lib.h +++ b/cpp2rust/converter/converter_lib.h @@ -154,4 +154,14 @@ bool ContainsVAArgExpr(const clang::Stmt *stmt); clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx); +std::vector +GetTopLevelSwitchCases(clang::SwitchStmt *stmt); + +bool SwitchCaseContainsDefault(clang::SwitchCase *c); + +std::vector GetSwitchCaseBody(clang::CompoundStmt *body, + clang::SwitchCase *head); + +bool SwitchHasFallthrough(clang::SwitchStmt *stmt); + } // namespace cpp2rust diff --git a/libcc2rs-macros/Cargo.toml b/libcc2rs-macros/Cargo.toml new file mode 100644 index 00000000..5a895659 --- /dev/null +++ b/libcc2rs-macros/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "libcc2rs-macros" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1" +quote = "1" +syn = { version = "2", features = ["full", "visit-mut", "extra-traits"] } + +[dev-dependencies] +trybuild = "1" diff --git a/libcc2rs-macros/src/goto.rs b/libcc2rs-macros/src/goto.rs new file mode 100644 index 00000000..916fd211 --- /dev/null +++ b/libcc2rs-macros/src/goto.rs @@ -0,0 +1,48 @@ +// Copyright (c) 2022-present INESC-ID. +// Distributed under the MIT license that can be found in the LICENSE file. + +use proc_macro::TokenStream; +use syn::parse::{Parse, ParseStream}; +use syn::{parse_macro_input, Expr, Lifetime, Token}; + +use crate::state_machine::{Arm, GotoStateMachine, StateMachine}; + +pub fn expand(input: TokenStream) -> TokenStream { + let GotoBlockInput { arms } = parse_macro_input!(input as GotoBlockInput); + GotoStateMachine { + arms: arms + .into_iter() + .map(|a| Arm { + label: a.label.ident.to_string(), + body: a.body, + }) + .collect(), + } + .emit() + .into() +} + +struct GotoBlockInput { + arms: Vec, +} + +struct GotoArm { + label: Lifetime, + body: Expr, +} + +impl Parse for GotoBlockInput { + fn parse(input: ParseStream) -> syn::Result { + let mut arms = Vec::new(); + while !input.is_empty() { + let label: Lifetime = input.parse()?; + input.parse::]>()?; + let body: Expr = input.parse()?; + arms.push(GotoArm { label, body }); + if input.peek(Token![,]) { + input.parse::()?; + } + } + Ok(Self { arms }) + } +} diff --git a/libcc2rs-macros/src/lib.rs b/libcc2rs-macros/src/lib.rs new file mode 100644 index 00000000..10317cf4 --- /dev/null +++ b/libcc2rs-macros/src/lib.rs @@ -0,0 +1,81 @@ +// Copyright (c) 2022-present INESC-ID. +// Distributed under the MIT license that can be found in the LICENSE file. + +use proc_macro::TokenStream; + +mod goto; +mod state_machine; +mod switch; + +// switch!(match { +// [if ] => { /* body; may contain break or continue */ }, +// ... +// _ => , +// }); +// +// Desugars to a goto_block! with a synthetic dispatch arm prepended. +// +// goto_block! { +// '__dispatch => { +// match { +// => { __s = 1; continue '__sm; } +// ... +// _ => break '__sm, +// } +// }, +// '__c1 => { /* body_1 with `break` rewritten to `break '__sm` */ }, +// ... +// '__cN => { /* body_N with same rewrite */ }, +// }; +// +// __sm is the inner label used to describe the state machine insinde goto_block. See goto_block! +// for more info. + +#[proc_macro] +pub fn switch(input: TokenStream) -> TokenStream { + switch::expand(input) +} + +// goto_block! { +// '