Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 107 additions & 65 deletions cpp2rust/converter/converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,15 @@ bool Converter::VisitRecordType(clang::RecordType *type) {
auto *decl = type->getDecl();
if (auto lambda = clang::dyn_cast<clang::CXXRecordDecl>(decl)) {
if (lambda->isLambda()) {
auto call_op = lambda->getLambdaCallOperator();
StrCat("Rc<dyn Fn(");
for (auto p : call_op->parameters()) {
StrCat(std::format("{},", ToStringBase(p->getType())));
}
StrCat(")");
if (!call_op->getReturnType()->isVoidType()) {
StrCat("->");
StrCat(ToStringBase(call_op->getReturnType()));
if (in_function_formals_) {
StrCat(
ConvertFunctionPointerType(lambda->getLambdaCallOperator()
->getType()
->getAs<clang::FunctionProtoType>(),
FnProtoType::LambdaCallOperator));
} else {
StrCat("_");
}
StrCat(">");
return false;
}
}
Expand Down Expand Up @@ -222,24 +220,25 @@ bool Converter::VisitLValueReferenceType(clang::LValueReferenceType *type) {
return Convert(pointee_type);
}

void Converter::ConvertFunctionPointerType(clang::PointerType *type) {
auto proto = type->getPointeeType()->getAs<clang::FunctionProtoType>();
assert(proto && "Type should be a function prototype");

StrCat("Rc<dyn Fn(");
std::string
Converter::ConvertFunctionPointerType(const clang::FunctionProtoType *proto,
FnProtoType kind) {
std::string result =
(kind == FnProtoType::LambdaCallOperator ? "impl Fn(" : "fn(");
for (auto p_ty : proto->param_types()) {
StrCat(std::format("{},", ToString(p_ty)));
result += ToString(p_ty) + ",";
}
StrCat(")");
result += ")";
if (!proto->getReturnType()->isVoidType()) {
StrCat(std::format("-> {}", ToString(proto->getReturnType())));
result += std::format(" -> {}", ToString(proto->getReturnType()));
}
StrCat(">");
return result;
}

bool Converter::VisitPointerType(clang::PointerType *type) {
if (type->getPointeeType()->getAs<clang::FunctionProtoType>()) {
ConvertFunctionPointerType(type);
if (auto proto = type->getPointeeType()->getAs<clang::FunctionProtoType>()) {
StrCat(std::format("Option<{} {}>", keyword_unsafe_,
ConvertFunctionPointerType(proto)));
return false;
}

Expand Down Expand Up @@ -423,6 +422,9 @@ bool Converter::ConvertVarDeclSkipInit(clang::VarDecl *decl) {
}

bool Converter::ConvertLambdaVarDecl(clang::VarDecl *decl) {
if (decl->getType()->isFunctionPointerType()) {
return false;
}
if (decl->hasInit()) {
if (clang::isa<clang::LambdaExpr>(
decl->getInit()->IgnoreUnlessSpelledInSource())) {
Expand Down Expand Up @@ -1377,6 +1379,17 @@ bool Converter::VisitCallExpr(clang::CallExpr *expr) {
return false;
}

void Converter::EmitFnPtrCall(clang::Expr *callee) {
StrCat(token::kOpenParen);
Convert(callee);
StrCat(").unwrap()");
}

void Converter::ConvertFunctionToFunctionPointer(
const clang::FunctionDecl *fn_decl) {
StrCat(std::format("Some({})", GetNamedDeclAsString(fn_decl)));
}

void Converter::ConvertGenericCallExpr(clang::CallExpr *expr) {
clang::Expr *callee = expr->getCallee();
auto convert_param_ty = [&](clang::QualType param_type, clang::Expr *expr) {
Expand All @@ -1399,7 +1412,8 @@ void Converter::ConvertGenericCallExpr(clang::CallExpr *expr) {
StrCat(token::kOpenParen);
StrCat(keyword_unsafe_);
StrCat(token::kOpenCurlyBracket);
const auto *function = expr->getCalleeDecl()->getAsFunction();
const auto *function =
expr->getCalleeDecl() ? expr->getCalleeDecl()->getAsFunction() : nullptr;
const clang::FunctionProtoType *proto = nullptr;

if (!function) {
Expand Down Expand Up @@ -1447,7 +1461,12 @@ void Converter::ConvertGenericCallExpr(clang::CallExpr *expr) {
}
}

Convert(callee);
if (proto && !function) {
EmitFnPtrCall(callee);
} else {
PushExprKind push(*this, ExprKind::Callee);
Convert(callee);
}
StrCat(token::kOpenParen);
for (unsigned i = 0; i < num_named_params && i < num_args; ++i) {
auto *arg = expr->getArg(i + arg_begin);
Expand Down Expand Up @@ -1668,7 +1687,15 @@ bool Converter::VisitImplicitCastExpr(clang::ImplicitCastExpr *expr) {
break;
}
case clang::CastKind::CK_FunctionToPointerDecay:
case clang::CastKind::CK_BuiltinFnToFnPtr:
case clang::CastKind::CK_BuiltinFnToFnPtr: {
if (isCallee()) {
Convert(sub_expr);
} else {
PushExprKind push(*this, ExprKind::AddrOf);
Convert(sub_expr);
}
break;
}
case clang::CastKind::CK_ConstructorConversion:
case clang::CastKind::CK_DerivedToBase:
Convert(sub_expr);
Expand All @@ -1692,7 +1719,11 @@ bool Converter::VisitImplicitCastExpr(clang::ImplicitCastExpr *expr) {
ConvertEqualsNullPtr(sub_expr);
break;
case clang::CastKind::CK_NullToPointer:
StrCat(keyword_default_);
if (type->isFunctionPointerType()) {
StrCat("None");
} else {
StrCat(keyword_default_);
}
computed_expr_type_ = ComputedExprType::FreshPointer;
break;
default:
Expand Down Expand Up @@ -1737,6 +1768,17 @@ bool Converter::VisitExplicitCastExpr(clang::ExplicitCastExpr *expr) {
if (expr->getType() == sub_expr->getType()) {
return Convert(sub_expr);
}
if (type->isFunctionPointerType() ||
sub_expr->getType()->isFunctionPointerType()) {
StrCat("std::mem::transmute::<");
Convert(sub_expr->getType());
StrCat(",");
Convert(type);
StrCat(">(");
Convert(sub_expr);
StrCat(")");
return false;
}
StrCat(token::kOpenParen);
Convert(sub_expr);
if (auto *unary_oper = clang::dyn_cast<clang::UnaryOperator>(sub_expr);
Expand Down Expand Up @@ -1963,12 +2005,12 @@ bool Converter::VisitConditionalOperator(clang::ConditionalOperator *expr) {
StrCat(keyword::kIf);
Convert(expr->getCond());
StrCat(token::kOpenCurlyBracket);
if (expr->isLValue() && !isRValue()) {
if (expr->isLValue() && !isRValue() && !expr->getType()->isFunctionType()) {
StrCat(token::kRef, keyword_mut_);
}
Convert(expr->getTrueExpr());
StrCat(token::kCloseCurlyBracket, keyword::kElse, token::kOpenCurlyBracket);
if (expr->isLValue() && !isRValue()) {
if (expr->isLValue() && !isRValue() && !expr->getType()->isFunctionType()) {
StrCat(token::kRef, keyword_mut_);
}
Convert(expr->getFalseExpr());
Expand Down Expand Up @@ -2022,35 +2064,23 @@ bool Converter::VisitDeclRefExpr(clang::DeclRefExpr *expr) {
return false;
}

if (auto function = clang::dyn_cast<clang::FunctionDecl>(decl)) {
if (auto *fn_decl = clang::dyn_cast<clang::FunctionDecl>(decl)) {
if (isAddrOf()) {
// Wrap unsafe function in safe closure because the Fn trait only accepts
// safe functions
std::string arguments;
for (unsigned i = 0; i < function->getNumParams(); ++i) {
arguments += (i ? ", a" : "a") + std::to_string(i);
}
StrCat("Rc::new", token::kOpenParen);
StrCat(std::format("|{}|", arguments));
StrCat(keyword_unsafe_, token::kOpenCurlyBracket);
StrCat(str);
StrCat(token::kOpenParen);
StrCat(arguments);
StrCat(token::kCloseParen);
StrCat(token::kCloseCurlyBracket);
StrCat(token::kCloseParen);
ConvertFunctionToFunctionPointer(fn_decl);
return false;
}
}

if (auto var_decl = clang::dyn_cast<clang::VarDecl>(decl)) {
if (auto init = var_decl->getInit()) {
if (auto lambda = clang::dyn_cast<clang::LambdaExpr>(
init->IgnoreUnlessSpelledInSource())) {
StrCat(token::kOpenParen);
VisitLambdaExpr(lambda);
StrCat(token::kCloseParen);
return false;
if (!var_decl->getType()->isFunctionPointerType()) {
if (auto init = var_decl->getInit()) {
if (auto lambda = clang::dyn_cast<clang::LambdaExpr>(
init->IgnoreUnlessSpelledInSource())) {
StrCat(token::kOpenParen);
VisitLambdaExpr(lambda);
StrCat(token::kCloseParen);
return false;
}
}
}
}
Expand Down Expand Up @@ -2509,6 +2539,9 @@ bool Converter::VisitCXXDefaultArgExpr(clang::CXXDefaultArgExpr *expr) {
}

bool Converter::VisitLambdaExpr(clang::LambdaExpr *expr) {
if (isAddrOf() && expr->capture_size() == 0) {
StrCat("Some");
}
StrCat(token::kOpenParen);
StrCat("|");
for (auto p : expr->getLambdaClass()->getLambdaCallOperator()->parameters()) {
Expand Down Expand Up @@ -2636,23 +2669,10 @@ bool Converter::VisitCXXStdInitializerListExpr(
return false;
}

std::string
Converter::GetFunctionPointerDefaultAsString(clang::QualType qual_type) {
std::string ret;
auto proto = qual_type->getPointeeType()->getAs<clang::FunctionProtoType>();
assert(proto);
ret = "Rc::new(|";
for (unsigned i = 0; i < proto->getNumParams(); ++i) {
ret += "_,";
}
ret += R"(| { panic!("ub: uninit function pointer") }))";
return ret;
}

std::string Converter::GetDefaultAsString(clang::QualType qual_type) {
if (qual_type->isPointerType()) {
if (qual_type->getPointeeType()->isFunctionType()) {
return GetFunctionPointerDefaultAsString(qual_type);
return "None";
} else {
computed_expr_type_ = ComputedExprType::FreshPointer;
return keyword_default_;
Expand Down Expand Up @@ -2800,6 +2820,16 @@ void Converter::ConvertVarInit(clang::QualType qual_type, clang::Expr *expr) {
StrCat(keyword_mut_);
}
}
if (qual_type->isFunctionPointerType()) {
if (auto *lambda = clang::dyn_cast<clang::LambdaExpr>(
expr->IgnoreUnlessSpelledInSource())) {
PushExprKind push(*this, ExprKind::AddrOf);
curr_init_type_.push(qual_type);
VisitLambdaExpr(lambda);
curr_init_type_.pop();
return;
}
}
auto *ignore_casts = expr->IgnoreCasts();
// FIXME: this looks very complicated
if (auto *ctor = clang::dyn_cast<clang::CXXConstructExpr>(ignore_casts);
Expand Down Expand Up @@ -2845,7 +2875,8 @@ void Converter::ConvertUnsignedArithOperand(clang::Expr *expr,
void Converter::ConvertEqualsNullPtr(clang::Expr *expr) {
StrCat("(");
Convert(expr);
if (IsUniquePtr(expr->getType())) {
if (IsUniquePtr(expr->getType()) ||
expr->getType()->isFunctionPointerType()) {
StrCat(").is_none()");
} else {
StrCat(").is_null()");
Expand Down Expand Up @@ -3229,6 +3260,13 @@ void Converter::PlaceholderCtx::dump() const {

std::string Converter::ConvertPlaceholder(clang::Expr *expr, clang::Expr *arg,
const PlaceholderCtx &ph_ctx) {
if (arg->getType()->isFunctionPointerType()) {
PushExprKind push(*this, ExprKind::Callee);
Buffer buf(*this);
Convert(arg);
return std::move(buf).str();
}

if (ph_ctx.needs_materialization()) {
auto materialized = ph_ctx.materialize_ctx->GetOrMaterialize(
static_cast<unsigned>(ph_ctx.materialize_idx),
Expand Down Expand Up @@ -3376,6 +3414,10 @@ bool Converter::isVoid() const {
return curr_expr_kind_.empty() || curr_expr_kind_.back() == ExprKind::Void;
}

bool Converter::isCallee() const {
return !curr_expr_kind_.empty() && curr_expr_kind_.back() == ExprKind::Callee;
}

void Converter::SetFresh() {
switch (computed_expr_type_) {
case ComputedExprType::Value:
Expand Down
17 changes: 14 additions & 3 deletions cpp2rust/converter/converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {

virtual bool VisitPointerType(clang::PointerType *type);

void ConvertFunctionPointerType(clang::PointerType *type);
enum class FnProtoType { LambdaCallOperator, FnPtr };

virtual std::string
ConvertFunctionPointerType(const clang::FunctionProtoType *proto,
FnProtoType kind = FnProtoType::FnPtr);

virtual bool VisitDecayedType(clang::DecayedType *type);

Expand Down Expand Up @@ -201,6 +205,11 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {

void ConvertGenericCallExpr(clang::CallExpr *expr);

virtual void EmitFnPtrCall(clang::Expr *callee);

virtual void
ConvertFunctionToFunctionPointer(const clang::FunctionDecl *fn_decl);

virtual void ConvertPrintf(clang::CallExpr *expr);

void ConvertVAArgCall(clang::CallExpr *expr);
Expand Down Expand Up @@ -334,8 +343,6 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
virtual bool Convert(clang::Stmt *stmt);
virtual bool Convert(clang::Expr *expr);

std::string GetFunctionPointerDefaultAsString(clang::QualType qual_type);

virtual std::string GetDefaultAsString(clang::QualType qual_type);

virtual std::string GetDefaultAsStringFallback(clang::QualType qual_type);
Expand Down Expand Up @@ -472,6 +479,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
static std::unordered_set<std::string> abstract_structs_;

enum class ExprKind : uint8_t {
Callee,
LValue,
RValue,
XValue,
Expand All @@ -482,6 +490,8 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {

static const char *expr_kind_to_string(ExprKind kind) {
switch (kind) {
case ExprKind::Callee:
return "Callee";
case ExprKind::LValue:
return "LValue";
case ExprKind::RValue:
Expand All @@ -505,6 +515,7 @@ class Converter : public clang::RecursiveASTVisitor<Converter> {
bool isAddrOf() const;
bool isObject() const;
bool isVoid() const;
bool isCallee() const;

void dump_expr_kinds();

Expand Down
Loading
Loading