Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 18 additions & 14 deletions src/ExpressiveSharp.Generator/PolyfillInterceptorGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,10 @@ private static string GetFileTag(string sourcePath)
{{delegateFqn}} __func,
params global::ExpressiveSharp.IExpressionTreeTransformer[] transformers)
{
global::System.ArgumentNullException.ThrowIfNull(transformers);
for (var __i = 0; __i < transformers.Length; __i++)
if (transformers[__i] is null)
throw new global::System.ArgumentNullException(nameof(transformers), $"transformers[{__i}] is null");
{{emitResult.Body}} global::System.Linq.Expressions.Expression result = __lambda;
foreach (var t in transformers) result = t.Transform(result);
return (global::System.Linq.Expressions.Expression<{{delegateFqn}}>)result;
Expand Down Expand Up @@ -598,6 +602,14 @@ private static string MethodId(string op, string fileTag, int line, int col)
typeParamNames[i] = $"T{i}";
typeParams = $"<{string.Join(", ", typeParamNames)}>";

// The unsubstituted method type parameter symbols carry per-position identity
// even when two distinct parameters happen to substitute to the same concrete
// type (e.g. GroupJoin's TInner=TKey=int) — used by the interceptor signature
// emission below. The substituted type entries remain so EmitLambdaBody (which
// sees substituted parameter symbols) can still resolve anonymous return types
// and element types into Tn aliases.
for (int i = 0; i < method.TypeParameters.Length; i++)
typeAliases[method.TypeParameters[i]] = typeParamNames[i];
if (!typeAliases.ContainsKey(elemSym))
typeAliases[elemSym] = typeParamNames[0];
for (int i = 0; i < methodTypeArgs.Length; i++)
Expand All @@ -611,18 +623,14 @@ private static string MethodId(string op, string fileTag, int line, int col)
else
elemRef = elemFqn;

var origMethod = method.OriginalDefinition;
var funcFqnGenerics = new string[funcParamIndices.Count];
for (int fi = 0; fi < funcParamIndices.Count; fi++)
{
var funcTypeArgs = ((INamedTypeSymbol)method.Parameters[funcParamIndices[fi]].Type).TypeArguments;
var funcTypeArgs = ((INamedTypeSymbol)origMethod.Parameters[funcParamIndices[fi]].Type).TypeArguments;
var sigParts = new string[funcTypeArgs.Length];
for (int i = 0; i < funcTypeArgs.Length; i++)
{
if (typeAliases.TryGetValue(funcTypeArgs[i], out var gp))
sigParts[i] = gp;
else
sigParts[i] = funcTypeArgs[i].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
}
sigParts[i] = ResolveTypeFqn(funcTypeArgs[i], typeAliases);
funcFqnGenerics[fi] = "global::System.Func<" + string.Join(", ", sigParts) + ">";
delegateFqns[fi] = funcFqnGenerics[fi];
}
Expand All @@ -633,12 +641,8 @@ private static string MethodId(string op, string fileTag, int line, int col)

if (isRewritableReturn)
{
if (typeAliases.TryGetValue(returnElemType!, out var retParam))
returnRef = retParam;
else
// Composite return types like IGrouping<TKey, AnonType> need alias substitution
// (anonymous types have no nameable form in C# source).
returnRef = ResolveTypeFqn(returnElemType!, typeAliases);
var origReturnElem = ((INamedTypeSymbol)origMethod.ReturnType).TypeArguments[0];
returnRef = ResolveTypeFqn(origReturnElem, typeAliases);
}
else
{
Expand All @@ -659,7 +663,7 @@ private static string MethodId(string op, string fileTag, int line, int col)
}
else
{
var paramType = method.Parameters[i].Type;
var paramType = origMethod.Parameters[i].Type;
var paramTypeFqn = ResolveTypeFqn(paramType, typeAliases);
var paramName = method.Parameters[i].Name;
interceptorParams.Add($"{paramTypeFqn} {paramName}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,24 @@ class C {

return Verifier.Verify(result.GeneratedTrees[0].ToString());
}

[TestMethod]
public void RangeSliceOnString_DoesNotEmitUnsupportedOperationDiagnostic()
{
var compilation = CreateCompilation(
"""
namespace Foo {
class C {
public string Label { get; set; }

[Expressive]
public string FirstFive => Label[..5];
}
}
""");
var result = RunExpressiveGenerator(compilation);

Assert.IsFalse(
result.Diagnostics.Any(d => d.Id == "EXP0008" && d.GetMessage().Contains("ImplicitIndexerReference")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,30 @@ class C {

return Verifier.Verify(result.GeneratedTrees[0].ToString());
}

[TestMethod]
public void ListPattern_OnArray_DoesNotEmitUnsupportedOperationDiagnostic()
{
var compilation = CreateCompilation(
"""
namespace Foo {
class C {
public int[] Items { get; set; }

[Expressive]
public string Shape() => Items switch
{
[] => "empty",
[_] => "one",
_ => "many",
};
}
}
""");
var result = RunExpressiveGenerator(compilation);

Assert.IsFalse(
result.Diagnostics.Any(d => d.Id == "EXP0008" && d.GetMessage().Contains("ListPattern")),
"Arrays have Length and indexer; list patterns on int[] should be supported.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace ExpressiveSharp.Generated.Interceptors
global::System.Collections.Generic.IEnumerable<T1> inner,
global::System.Func<T0, T2> __func1,
global::System.Func<T1, T2> __func2,
global::System.Func<T0, global::System.Collections.Generic.IEnumerable<global::TestNs.Customer>, T3> __func3)
global::System.Func<T0, global::System.Collections.Generic.IEnumerable<T1>, T3> __func3)
{
// Source: o => o.CustomerId
var i361d11c19a_p_o = global::System.Linq.Expressions.Expression.Parameter(typeof(T0), "o");
Expand All @@ -28,7 +28,7 @@ namespace ExpressiveSharp.Generated.Interceptors
var i361d11c19c_expr_2 = global::System.Linq.Expressions.Expression.Call(global::System.Linq.Enumerable.First(global::System.Linq.Enumerable.Where(typeof(global::System.Linq.Enumerable).GetMethods(global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Static), m => m.Name == "Count" && m.IsGenericMethodDefinition && m.GetGenericArguments().Length == 1 && m.GetParameters().Length == 1 && m.GetParameters()[0].ParameterType.IsGenericType && !m.GetParameters()[0].ParameterType.IsGenericParameter)).MakeGenericMethod(typeof(T1)), new global::System.Linq.Expressions.Expression[] { i361d11c19c_p_cs }); // cs.Count()
var i361d11c19c_expr_3 = typeof(T3).GetConstructors()[0];
var i361d11c19c_expr_0 = global::System.Linq.Expressions.Expression.New(i361d11c19c_expr_3, new global::System.Linq.Expressions.Expression[] { i361d11c19c_expr_1, i361d11c19c_expr_2 }, new global::System.Reflection.MemberInfo[] { typeof(T3).GetProperty("CustomerId"), typeof(T3).GetProperty("Count") });
var __lambda3 = global::System.Linq.Expressions.Expression.Lambda<global::System.Func<T0, global::System.Collections.Generic.IEnumerable<global::TestNs.Customer>, T3>>(i361d11c19c_expr_0, i361d11c19c_p_o, i361d11c19c_p_cs);
var __lambda3 = global::System.Linq.Expressions.Expression.Lambda<global::System.Func<T0, global::System.Collections.Generic.IEnumerable<T1>, T3>>(i361d11c19c_expr_0, i361d11c19c_p_o, i361d11c19c_p_cs);
return (global::ExpressiveSharp.IExpressiveQueryable<T3>)(object)
global::ExpressiveSharp.ExpressiveQueryableExtensions.AsExpressive(
global::System.Linq.Queryable.GroupJoin(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using VerifyMSTest;
using ExpressiveSharp.Generator.Tests.Infrastructure;
Expand Down Expand Up @@ -136,4 +138,45 @@ public void Run(System.Linq.IQueryable<Order> orders, System.Collections.Generic

return Verifier.Verify(result.GeneratedTrees[0].GetText().ToString());
}

[TestMethod]
public void GroupJoin_AnonymousResultSelector_GeneratedInterceptorCompiles()
{
var source =
"""
using ExpressiveSharp;

namespace TestNs
{
class Order { public int CustomerId { get; set; } }
class Customer { public int Id { get; set; } public string Name { get; set; } }
class TestClass
{
public void Run(System.Linq.IQueryable<Order> orders, System.Collections.Generic.IEnumerable<Customer> customers)
{
orders.AsExpressive()
.GroupJoin(customers,
o => o.CustomerId,
c => c.Id,
(o, cs) => new { o.CustomerId, Count = cs.Count() })
.ToList();
}
}
}
""";

var compilation = CreateCompilation(source);
var subject = new global::ExpressiveSharp.Generator.PolyfillInterceptorGenerator();
GeneratorDriver driver = CSharpGeneratorDriver
.Create(subject)
.WithUpdatedParseOptions((CSharpParseOptions)compilation.SyntaxTrees.First().Options)
.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _);

var errors = outputCompilation.GetDiagnostics()
.Where(d => d.Severity == DiagnosticSeverity.Error)
.Where(d => d.Id != "CS9137")
.ToList();

Comment on lines +174 to +179
Assert.AreEqual(0, errors.Count, string.Join("\n", errors.Select(d => d.ToString())));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
using System.Linq.Expressions;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace ExpressiveSharp.IntegrationTests.Tests;

[TestClass]
public class ExpansionEdgeCasesTests
{
[TestMethod]
public void VirtualMethod_ExpansionPreservesPolymorphicDispatch()
{
var derived = new VirtualDispatchDerived { Id = 7, Name = "x" };
var directCall = ((VirtualDispatchBase)derived).Describe();

Expression<Func<VirtualDispatchBase, string>> expr = b => b.Describe();
var expanded = (Expression<Func<VirtualDispatchBase, string>>)expr.ExpandExpressives();
var fromExpansion = expanded.Compile()(derived);

Assert.AreEqual(directCall, fromExpansion);
}

[TestMethod]
public void Polyfill_TypePatternSwitch_WithNullArm_BuildsExpression()
{
var expr = ExpressionPolyfill.Create((object o) => o switch
{
int i => i + 1,
string s => s.Length,
null => -1,
_ => 0,
});

var fn = expr.Compile();
Assert.AreEqual(6, fn(5));
Assert.AreEqual(5, fn("hello"));
Assert.AreEqual(-1, fn(null!));
}

[TestMethod]
public void Polyfill_LambdaCapturingTupleLocal_BuildsExpression()
{
var p = (X: 3, Y: 4);
var expr = ExpressionPolyfill.Create((int n) => n + p.X + p.Y);
Assert.AreEqual(8, expr.Compile()(1));
}

[TestMethod]
public void Polyfill_LambdaWithListPattern_IsIntercepted()
{
var expr = ExpressionPolyfill.Create((int[] a) => a switch
{
[] => "empty",
[_] => "one",
_ => "many",
});

var fn = expr.Compile();
Assert.AreEqual("empty", fn(Array.Empty<int>()));
Assert.AreEqual("one", fn(new[] { 1 }));
Assert.AreEqual("many", fn(new[] { 1, 2 }));
}

[Ignore("ExpressiveReplacer has no recursion guard; expansion currently throws StackOverflowException, which is uncatchable and terminates the test runner.")]
[TestMethod]
public void RecursiveExpressiveMember_ExpandsWithoutCrashingTheProcess()
{
Expression<Func<RecursiveTree, int>> expr = t => t.Sum;
var expanded = expr.ExpandExpressives();
Assert.IsNotNull(expanded);
}
}

public class VirtualDispatchBase
{
public int Id { get; set; }

[Expressive]
public virtual string Describe() => $"base#{Id}";
}

public class VirtualDispatchDerived : VirtualDispatchBase
{
public string? Name { get; set; }

[Expressive]
public override string Describe() => $"derived#{Id}/{Name}";
}

public class RecursiveTree
{
public int Value { get; set; }
public RecursiveTree? Left { get; set; }
public RecursiveTree? Right { get; set; }

[Expressive]
public int Sum => Value
+ (Left == null ? 0 : Left.Sum)
+ (Right == null ? 0 : Right.Sum);
}
12 changes: 12 additions & 0 deletions tests/ExpressiveSharp.Tests/Extensions/ExpressionPolyfillTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace ExpressiveSharp.Tests.Extensions;

[TestClass]
public class ExpressionPolyfillTests
{
[TestMethod]
public void Create_NullTransformerArgument_ThrowsArgumentNullException()
{
Assert.ThrowsExactly<ArgumentNullException>(() =>
ExpressionPolyfill.Create((int n) => n + 1, (IExpressionTreeTransformer)null!));
}
}
Loading