Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 51 additions & 31 deletions cpp2rust/converter/converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,12 +965,8 @@ bool Converter::VisitReturnStmt(clang::ReturnStmt *stmt) {
}

void Converter::ConvertCondition(clang::Expr *cond) {
if (!cond->getType()->isBooleanType()) {
PushExprKind push(*this, ExprKind::RValue);
Convert(CreateConversionToBool(cond, ctx_));
return;
}
Convert(cond);
PushExprKind push(*this, ExprKind::RValue);
Convert(NormalizeToBool(cond, ctx_));
}

bool Converter::VisitIfStmt(clang::IfStmt *stmt) {
Expand Down Expand Up @@ -1741,6 +1737,41 @@ void Converter::ConvertIntegerToEnumeralCast(clang::Expr *to,
}
}

void Converter::ConvertIntegralToBooleanCast(clang::ImplicitCastExpr *expr) {
auto sub_expr = expr->getSubExpr();
auto *stripped = sub_expr->IgnoreParenImpCasts();

if (auto binop = clang::dyn_cast<clang::BinaryOperator>(stripped)) {
// Comparison already produces bool, no wrap needed.
if (binop->isComparisonOp()) {
Convert(sub_expr);
return;
}
// Distribute bool conversion to each argument of the logical op.
if (binop->isLogicalOp()) {
{
PushParen paren(*this);
ConvertCondition(binop->getLHS());
}
StrCat(binop->getOpcodeStr());
{
PushParen paren(*this);
ConvertCondition(binop->getRHS());
}
return;
}
}

PushParen paren(*this);
Convert(sub_expr);
StrCat(token::kDiff);
if (sub_expr->getType()->isEnumeralType()) {
StrCat(GetUnsafeTypeAsString(sub_expr->getType()), "::from(0)");
} else /* sub_expr->getType()->isIntegerType() */ {
StrCat(token::kZero);
}
}

bool Converter::VisitImplicitCastExpr(clang::ImplicitCastExpr *expr) {
auto *sub_expr = expr->getSubExpr();
auto type = expr->getType();
Expand Down Expand Up @@ -1816,20 +1847,7 @@ bool Converter::VisitImplicitCastExpr(clang::ImplicitCastExpr *expr) {
Convert(sub_expr);
break;
case clang::CastKind::CK_IntegralToBoolean:
if (auto binop = clang::dyn_cast<clang::BinaryOperator>(
sub_expr->IgnoreParenImpCasts())) {
// This already produces bool, no need for != 0
if (binop->isComparisonOp()) {
Convert(sub_expr);
break;
}
}

{
PushParen paren(*this);
Convert(sub_expr);
StrCat(token::kDiff, token::kZero);
}
ConvertIntegralToBooleanCast(expr);
break;
case clang::CastKind::CK_PointerToBoolean:
StrCat(token::kNot);
Expand Down Expand Up @@ -2130,6 +2148,18 @@ bool Converter::VisitUnaryOperator(clang::UnaryOperator *expr) {
Convert(sub_expr);
computed_expr_type_ = ComputedExprType::FreshValue;
break;
case clang::UO_LNot: {
bool needs_int_cast =
expr->getType()->isIntegerType() && !expr->getType()->isBooleanType();
PushParen paren_cast(*this, needs_int_cast);
StrCat(token::kNot);
ConvertCondition(sub_expr);
if (needs_int_cast) {
ConvertCast(expr->getType());
}
computed_expr_type_ = ComputedExprType::FreshValue;
break;
}
case clang::UO_Minus:
if (auto *literal = clang::dyn_cast<clang::IntegerLiteral>(sub_expr)) {
if (sub_expr->getType()->isUnsignedIntegerType()) {
Expand Down Expand Up @@ -2173,7 +2203,7 @@ void Converter::EmitStmtExprTail(clang::Expr *tail) { Convert(tail); }

bool Converter::VisitConditionalOperator(clang::ConditionalOperator *expr) {
StrCat(keyword::kIf);
Convert(expr->getCond());
ConvertCondition(expr->getCond());
{
PushBrace then_brace(*this);
if (expr->isLValue() && !isRValue() && !expr->getType()->isFunctionType()) {
Expand Down Expand Up @@ -2290,21 +2320,11 @@ bool Converter::VisitParenExpr(clang::ParenExpr *expr) {
}
}

// Add cast to avoid ambigous integers. Don't add cast if sub expression is a
// pointer dereference because we might want to mutate the dereferenced value.
bool should_add_integral_cast =
expr->getType()->isIntegralOrEnumerationType() && !isAddrOf() &&
!isVoid() && !clang::isa<clang::UnaryOperator>(expr->getSubExpr());
PushParen outer(*this, should_add_integral_cast);

{
PushParen inner(*this);
Convert(expr->getSubExpr());
}

if (should_add_integral_cast) {
ConvertCast(expr->getType());
}
return false;
}

Expand Down
2 changes: 2 additions & 0 deletions cpp2rust/converter/converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {

void ConvertIntegerToEnumeralCast(clang::Expr *to, clang::Expr *from);

void ConvertIntegralToBooleanCast(clang::ImplicitCastExpr *expr);

virtual bool VisitImplicitCastExpr(clang::ImplicitCastExpr *expr);

virtual bool VisitExplicitCastExpr(clang::ExplicitCastExpr *expr);
Expand Down
27 changes: 25 additions & 2 deletions cpp2rust/converter/converter_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,9 +683,32 @@ bool ContainsVAArgExpr(const clang::Stmt *stmt) {
return false;
}

clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx) {
clang::Expr *NormalizeToBool(clang::Expr *expr, clang::ASTContext &ctx) {
if (expr->getType()->isBooleanType()) {
return expr;
}

// If logical not returns integer, then craft a new logical not that returns
// bool.
if (auto bin = clang::dyn_cast<clang::UnaryOperator>(expr)) {
if (bin->getOpcode() == clang::UO_LNot) {
return clang::UnaryOperator::Create(
ctx, bin->getSubExpr(), clang::UO_LNot, ctx.BoolTy, clang::VK_PRValue,
clang::OK_Ordinary, clang::SourceLocation(), false,
clang::FPOptionsOverride());
}
}

// Either to pointer -> bool, or int -> bool.
clang::CastKind cast_kind;
if (expr->getType()->isPointerType()) {
cast_kind = clang::CK_PointerToBoolean;
} else /* expr->getType()->isIntegerType() */ {
cast_kind = clang::CK_IntegralToBoolean;
}

return clang::ImplicitCastExpr::Create(
ctx, ctx.BoolTy, clang::CK_IntegralToBoolean, expr,
ctx, ctx.BoolTy, cast_kind, expr,
/*BasePath=*/nullptr, clang::VK_PRValue, clang::FPOptionsOverride());
}

Expand Down
2 changes: 1 addition & 1 deletion cpp2rust/converter/converter_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ bool IsBuiltinVaCopy(const clang::CallExpr *expr);

bool ContainsVAArgExpr(const clang::Stmt *stmt);

clang::Expr *CreateConversionToBool(clang::Expr *expr, clang::ASTContext &ctx);
clang::Expr *NormalizeToBool(clang::Expr *expr, clang::ASTContext &ctx);

std::vector<clang::SwitchCase *>
GetTopLevelSwitchCases(clang::SwitchStmt *stmt);
Expand Down
17 changes: 13 additions & 4 deletions cpp2rust/converter/models/converter_refcount.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,8 @@ bool ConverterRefCount::ConvertIncAndDec(clang::UnaryOperator *expr) {

bool ConverterRefCount::VisitConditionalOperator(
clang::ConditionalOperator *expr) {
StrCat(keyword::kIf, ConvertRValue(expr->getCond()));
StrCat(keyword::kIf);
ConvertCondition(expr->getCond());
{
PushBrace then_brace(*this);
StrCat(ConvertFresh(expr->getTrueExpr()));
Expand Down Expand Up @@ -1046,9 +1047,12 @@ bool ConverterRefCount::VisitImplicitCastExpr(clang::ImplicitCastExpr *expr) {
}
}

if (expr->getCastKind() == clang::CastKind::CK_NullToPointer &&
expr->getType()->isFunctionPointerType()) {
StrCat("FnPtr::null()");
if (expr->getCastKind() == clang::CastKind::CK_NullToPointer) {
if (expr->getType()->isFunctionPointerType()) {
StrCat("FnPtr::null()");
} else {
StrCat("Default::default()");
}
computed_expr_type_ = ComputedExprType::FreshPointer;
return false;
}
Expand Down Expand Up @@ -1137,6 +1141,11 @@ bool ConverterRefCount::VisitExplicitCastExpr(clang::ExplicitCastExpr *expr) {
}
return false;
}
if (expr->getCastKind() == clang::CK_NullToPointer) {
StrCat("Default::default()");
computed_expr_type_ = ComputedExprType::FreshPointer;
return false;
}
switch (expr->getStmtClass()) {
case clang::Stmt::CXXReinterpretCastExprClass:
assert(expr->getType()->isPointerType() &&
Expand Down
5 changes: 2 additions & 3 deletions tests/ub/out/refcount/ub4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ use std::io::{Read, Seek, Write};
use std::os::fd::AsFd;
use std::rc::{Rc, Weak};
pub fn smaller_0(x1: Ptr<i32>, x2: Ptr<i32>) -> Ptr<i32> {
return if (({
return if ({
let _lhs = (x1.read());
_lhs < (x2.read())
}) as bool)
{
}) {
(x1).clone()
} else {
(x2).clone()
Expand Down
6 changes: 1 addition & 5 deletions tests/ub/out/unsafe/ub4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@ use std::io::{Read, Seek, Write};
use std::os::fd::{AsFd, FromRawFd, IntoRawFd};
use std::rc::Rc;
pub unsafe fn smaller_0(x1: *mut i32, x2: *mut i32) -> *mut i32 {
return if (((*x1) < (*x2)) as bool) {
(x1)
} else {
(x2)
};
return if ((*x1) < (*x2)) { (x1) } else { (x2) };
}
pub fn main() {
unsafe {
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/bool_condition_enum.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include <assert.h>

enum Code { CODE_OK = 0, CODE_ERR = 1, CODE_FATAL = 2 };

int main() {
Code code = CODE_OK;
Code err = CODE_ERR;

if (code) {
assert(false);
}
if (!code) {
assert(true);
}
if (err) {
assert(true);
}
if (!err) {
assert(false);
}

int t9 = !code;
assert(t9 == 1);

bool b4 = code;
assert(!b4);

return 0;
}
30 changes: 30 additions & 0 deletions tests/unit/bool_condition_enum_c.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include <assert.h>
#include <stdbool.h>

enum Code { CODE_OK = 0, CODE_ERR = 1, CODE_FATAL = 2 };

int main() {
enum Code code = CODE_OK;
enum Code err = CODE_ERR;

if (code) {
assert(false);
}
if (!code) {
assert(true);
}
if (err) {
assert(true);
}
if (!err) {
assert(false);
}

int t9 = !code;
assert(t9 == 1);

bool b4 = code;
assert(!b4);

return 0;
}
63 changes: 63 additions & 0 deletions tests/unit/bool_condition_int.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include <assert.h>

int main() {
int n = 3;
int zero = 0;
unsigned u = 4;
unsigned long ul = 5;
long long ll = 6;
char ch = 'a';

if (n) {
assert(true);
}
if (!n) {
assert(false);
}
if (zero) {
assert(false);
}
if (!zero) {
assert(true);
}

if (u) {
assert(true);
}
if (ul) {
assert(true);
}
if (ll) {
assert(true);
}
if (ch) {
assert(true);
}

int loop_count = 0;
int counter = 3;
while (counter) {
--counter;
++loop_count;
}
assert(loop_count == 3);

for (int i = 5; i; --i) {
++loop_count;
}
assert(loop_count == 8);

int t = n ? 100 : 200;
assert(t == 100);
int t2 = zero ? 100 : 200;
assert(t2 == 200);
int t7 = !n;
assert(t7 == 0);
int t8 = !zero;
assert(t8 == 1);

bool b1 = n;
assert(b1);

return 0;
}
Loading
Loading