Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 64 additions & 47 deletions cpp2rust/converter/converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,7 @@ bool Converter::VisitIfStmt(clang::IfStmt *stmt) {
}

bool Converter::VisitWhileStmt(clang::WhileStmt *stmt) {
PushBreakTarget push(break_target_stack_, BreakTarget::Loop);
StrCat("'loop_:");
StrCat(keyword::kWhile);
ConvertCondition(stmt->getCond());
Expand All @@ -1002,6 +1003,7 @@ bool Converter::VisitWhileStmt(clang::WhileStmt *stmt) {
}

bool Converter::VisitDoStmt(clang::DoStmt *stmt) {
PushBreakTarget push(break_target_stack_, BreakTarget::Loop);
StrCat("'loop_:");
StrCat(keyword::kLoop, token::kOpenCurlyBracket);
curr_for_inc_.emplace(nullptr);
Expand All @@ -1016,6 +1018,7 @@ bool Converter::VisitDoStmt(clang::DoStmt *stmt) {
}

bool Converter::VisitForStmt(clang::ForStmt *stmt) {
PushBreakTarget push(break_target_stack_, BreakTarget::Loop);
Convert(stmt->getInit());
StrCat("'loop_:");
StrCat(keyword::kWhile);
Expand Down Expand Up @@ -1055,6 +1058,7 @@ void Converter::ConvertLoopVariable(clang::VarDecl *decl,

void Converter::ConvertForRangeBody(clang::CXXForRangeStmt *stmt,
const clang::VarDecl *map_iter_decl) {
PushBreakTarget push(break_target_stack_, BreakTarget::Loop);
std::optional<ScopedMapIterDecl> skip;
if (map_iter_decl)
skip.emplace(*this, map_iter_decl);
Expand Down Expand Up @@ -1136,10 +1140,12 @@ bool Converter::VisitCXXForRangeStmtIndexBased(clang::CXXForRangeStmt *stmt,
}

bool Converter::VisitBreakStmt([[maybe_unused]] clang::BreakStmt *stmt) {
StrCat(keyword::kBreak);
if (break_with_explicit_label_) {
if (break_target_stack_.isRegularSwitch()) {
StrCat(keyword::kBreak);
StrCat("'switch");
return false;
}
StrCat(keyword::kBreak);
return false;
}

Expand Down Expand Up @@ -2617,69 +2623,80 @@ bool Converter::VisitImplicitValueInitExpr(clang::ImplicitValueInitExpr *expr) {
return false;
}

static std::unordered_set<clang::SwitchCase *> visited_cases;

bool Converter::VisitSwitchCase(clang::SwitchCase *stmt) {
if (visited_cases.contains(stmt)) {
return false;
}
visited_cases.insert(stmt);

if (auto case_stmt = clang::dyn_cast<clang::CaseStmt>(stmt)) {
Convert(case_stmt->getLHS());
}

if (clang::isa<clang::CaseStmt>(stmt->getSubStmt())) {
StrCat("|| v == ");
} else {
if (clang::isa<clang::CaseStmt>(stmt)) {
StrCat(" => {");
} else {
StrCat("_ => {");
clang::Stmt *cur = stmt;
clang::SwitchCase *last = nullptr;
bool first = true;

while (auto *sc = clang::dyn_cast<clang::SwitchCase>(cur)) {
if (auto *case_stmt = clang::dyn_cast<clang::CaseStmt>(sc)) {
if (!first) {
StrCat("|| v == ");
}
Convert(case_stmt->getLHS());
}
last = sc;
first = false;
cur = sc->getSubStmt();
}

Convert(stmt->getSubStmt());
if (clang::isa<clang::CaseStmt>(last)) {
StrCat(" => {");
} else /* DefaultStmt */ {
StrCat("_ => {");
}
return false;
}

bool Converter::VisitSwitchStmt(clang::SwitchStmt *stmt) {
StrCat("'switch: {");
StrCat(std::format("let __match_cond = {};", ToString(stmt->getCond())));
StrCat("match __match_cond");
StrCat("{");

bool has_default_case = false;
auto body = llvm::cast<clang::CompoundStmt>(stmt->getBody());
auto *body = clang::dyn_cast<clang::CompoundStmt>(stmt->getBody());
assert(body);

break_with_explicit_label_ = true;
for (auto it = body->body_begin(), end = body->body_end(); it != end;) {
if (auto switch_case = clang::dyn_cast<clang::SwitchCase>(*it)) {
if (clang::isa<clang::CaseStmt>(switch_case)) {
StrCat("v if v == ");
} else {
has_default_case = true;
}
VisitSwitchCase(switch_case);
++it;
}
bool has_fallthrough = SwitchHasFallthrough(stmt);

while (it != end && !clang::isa<clang::SwitchCase>(*it)) {
Convert(*it);
++it;
}
if (has_fallthrough) {
StrCat("switch!(match ", ToString(stmt->getCond()), " {");
} else {
StrCat("'switch: {");
StrCat(std::format("let __match_cond = {};", ToString(stmt->getCond())));
StrCat("match __match_cond");
StrCat("{");
}

PushBreakTarget push(break_target_stack_, has_fallthrough
? BreakTarget::FallthroughSwitch
: BreakTarget::RegularSwitch);

clang::SwitchCase *default_case = nullptr;
for (auto *sc : GetTopLevelSwitchCases(stmt)) {
if (SwitchCaseContainsDefault(sc)) {
default_case = sc;
continue;
}
StrCat("v if v == ");
VisitSwitchCase(sc);
for (auto *t : GetSwitchCaseBody(body, sc)) {
Convert(t);
}
StrCat("},");
}

if (!has_default_case) {
if (default_case) {
StrCat("_ => {");
for (auto *t : GetSwitchCaseBody(body, default_case)) {
Convert(t);
}
StrCat("},");
} else {
StrCat(R"( _ => {})");
}
break_with_explicit_label_ = false;

StrCat("}");
StrCat("}");
if (has_fallthrough) {
StrCat("})");
} else {
StrCat("}");
StrCat("}");
}
return false;
}

Expand Down
30 changes: 29 additions & 1 deletion cpp2rust/converter/converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,35 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
clang::ASTContext &ctx_;
clang::FunctionDecl *curr_function_ = nullptr;
bool in_function_formals_ = false;
bool break_with_explicit_label_ = false;

enum class BreakTarget { Loop, RegularSwitch, FallthroughSwitch };
class BreakTargetStack {
public:
void push(BreakTarget t) { stack_.push(t); }
void pop() { stack_.pop(); }
bool isRegularSwitch() const {
return !stack_.empty() && stack_.top() == BreakTarget::RegularSwitch;
}

private:
std::stack<BreakTarget> stack_;
};
BreakTargetStack break_target_stack_;

class PushBreakTarget {
public:
PushBreakTarget(BreakTargetStack &stack, BreakTarget target)
: stack_(stack) {
stack_.push(target);
}
~PushBreakTarget() { stack_.pop(); }
PushBreakTarget(const PushBreakTarget &) = delete;
PushBreakTarget &operator=(const PushBreakTarget &) = delete;

private:
BreakTargetStack &stack_;
};

std::stack<clang::Expr *> curr_for_inc_;
std::stack<clang::QualType> curr_init_type_;

Expand Down
74 changes: 74 additions & 0 deletions cpp2rust/converter/converter_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,4 +660,78 @@ clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx) {
/*BasePath=*/nullptr, clang::VK_PRValue, clang::FPOptionsOverride());
}

std::vector<clang::SwitchCase *>
GetTopLevelSwitchCases(clang::SwitchStmt *stmt) {
std::vector<clang::SwitchCase *> cases;
if (auto *body = llvm::dyn_cast<clang::CompoundStmt>(stmt->getBody())) {
for (auto *s : body->body()) {
if (auto *sc = clang::dyn_cast<clang::SwitchCase>(s)) {
cases.push_back(sc);
}
}
}
return cases;
}

bool SwitchCaseContainsDefault(clang::SwitchCase *c) {
for (clang::Stmt *cur = c;;) {
if (clang::isa<clang::DefaultStmt>(cur)) {
return true;
}
auto *sc = clang::dyn_cast<clang::SwitchCase>(cur);
if (!sc) {
return false;
}
cur = sc->getSubStmt();
}
return false;
}

static clang::Stmt *GetLastStmtOfSwitchCase(clang::SwitchCase *c) {
clang::Stmt *cur = c->getSubStmt();
while (auto *sc = clang::dyn_cast<clang::SwitchCase>(cur)) {
cur = sc->getSubStmt();
}
return cur;
}

std::vector<clang::Stmt *> GetSwitchCaseBody(clang::CompoundStmt *body,
clang::SwitchCase *head) {
std::vector<clang::Stmt *> out;
out.push_back(GetLastStmtOfSwitchCase(head));
auto it = body->body_begin(), end = body->body_end();
while (it != end && *it != head) {
++it;
}
assert(it != end);
++it;
while (it != end && !clang::isa<clang::SwitchCase>(*it)) {
out.push_back(*it);
++it;
}
return out;
}

static bool SwitchCaseHasFallthrough(clang::Stmt *stmt) {
if (stmt) {
if (clang::isa<clang::BreakStmt>(stmt) ||
clang::isa<clang::ReturnStmt>(stmt)) {
return false;
}
}
return true;
}

bool SwitchHasFallthrough(clang::SwitchStmt *stmt) {
if (auto *body = clang::dyn_cast<clang::CompoundStmt>(stmt->getBody())) {
for (auto top_level_case : GetTopLevelSwitchCases(stmt)) {
auto arm = GetSwitchCaseBody(body, top_level_case);
if (arm.empty() || SwitchCaseHasFallthrough(arm.back())) {
return true;
}
}
}
return false;
}

} // namespace cpp2rust
10 changes: 10 additions & 0 deletions cpp2rust/converter/converter_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,14 @@ bool ContainsVAArgExpr(const clang::Stmt *stmt);

clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx);

std::vector<clang::SwitchCase *>
GetTopLevelSwitchCases(clang::SwitchStmt *stmt);

bool SwitchCaseContainsDefault(clang::SwitchCase *c);

std::vector<clang::Stmt *> GetSwitchCaseBody(clang::CompoundStmt *body,
clang::SwitchCase *head);

bool SwitchHasFallthrough(clang::SwitchStmt *stmt);

} // namespace cpp2rust
15 changes: 15 additions & 0 deletions libcc2rs-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[package]
name = "libcc2rs-macros"
version = "0.1.0"
edition = "2021"

[lib]
proc-macro = true

[dependencies]
proc-macro2 = "1"
quote = "1"
syn = { version = "2", features = ["full", "visit-mut", "extra-traits"] }

[dev-dependencies]
trybuild = "1"
48 changes: 48 additions & 0 deletions libcc2rs-macros/src/goto.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2022-present INESC-ID.
// Distributed under the MIT license that can be found in the LICENSE file.

use proc_macro::TokenStream;
use syn::parse::{Parse, ParseStream};
use syn::{parse_macro_input, Expr, Lifetime, Token};

use crate::state_machine::{Arm, GotoStateMachine, StateMachine};

pub fn expand(input: TokenStream) -> TokenStream {
let GotoBlockInput { arms } = parse_macro_input!(input as GotoBlockInput);
GotoStateMachine {
arms: arms
.into_iter()
.map(|a| Arm {
label: a.label.ident.to_string(),
body: a.body,
})
.collect(),
}
.emit()
.into()
}

struct GotoBlockInput {
arms: Vec<GotoArm>,
}

struct GotoArm {
label: Lifetime,
body: Expr,
}

impl Parse for GotoBlockInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut arms = Vec::new();
while !input.is_empty() {
let label: Lifetime = input.parse()?;
input.parse::<Token![=>]>()?;
let body: Expr = input.parse()?;
arms.push(GotoArm { label, body });
if input.peek(Token![,]) {
input.parse::<Token![,]>()?;
}
}
Ok(Self { arms })
}
}
Loading
Loading