diff --git a/src/ExpressiveSharp.Generator/PolyfillInterceptorGenerator.cs b/src/ExpressiveSharp.Generator/PolyfillInterceptorGenerator.cs index 98650de9..55492a6c 100644 --- a/src/ExpressiveSharp.Generator/PolyfillInterceptorGenerator.cs +++ b/src/ExpressiveSharp.Generator/PolyfillInterceptorGenerator.cs @@ -340,6 +340,10 @@ private static string GetFileTag(string sourcePath) {{delegateFqn}} __func, params global::ExpressiveSharp.IExpressionTreeTransformer[] transformers) { + global::System.ArgumentNullException.ThrowIfNull(transformers); + for (var __i = 0; __i < transformers.Length; __i++) + if (transformers[__i] is null) + throw new global::System.ArgumentNullException(nameof(transformers), $"transformers[{__i}] is null"); {{emitResult.Body}} global::System.Linq.Expressions.Expression result = __lambda; foreach (var t in transformers) result = t.Transform(result); return (global::System.Linq.Expressions.Expression<{{delegateFqn}}>)result; @@ -598,6 +602,14 @@ private static string MethodId(string op, string fileTag, int line, int col) typeParamNames[i] = $"T{i}"; typeParams = $"<{string.Join(", ", typeParamNames)}>"; + // The unsubstituted method type parameter symbols carry per-position identity + // even when two distinct parameters happen to substitute to the same concrete + // type (e.g. GroupJoin's TInner=TKey=int) — used by the interceptor signature + // emission below. The substituted type entries remain so EmitLambdaBody (which + // sees substituted parameter symbols) can still resolve anonymous return types + // and element types into Tn aliases. + for (int i = 0; i < method.TypeParameters.Length; i++) + typeAliases[method.TypeParameters[i]] = typeParamNames[i]; if (!typeAliases.ContainsKey(elemSym)) typeAliases[elemSym] = typeParamNames[0]; for (int i = 0; i < methodTypeArgs.Length; i++) @@ -611,18 +623,14 @@ private static string MethodId(string op, string fileTag, int line, int col) else elemRef = elemFqn; + var origMethod = method.OriginalDefinition; var funcFqnGenerics = new string[funcParamIndices.Count]; for (int fi = 0; fi < funcParamIndices.Count; fi++) { - var funcTypeArgs = ((INamedTypeSymbol)method.Parameters[funcParamIndices[fi]].Type).TypeArguments; + var funcTypeArgs = ((INamedTypeSymbol)origMethod.Parameters[funcParamIndices[fi]].Type).TypeArguments; var sigParts = new string[funcTypeArgs.Length]; for (int i = 0; i < funcTypeArgs.Length; i++) - { - if (typeAliases.TryGetValue(funcTypeArgs[i], out var gp)) - sigParts[i] = gp; - else - sigParts[i] = funcTypeArgs[i].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - } + sigParts[i] = ResolveTypeFqn(funcTypeArgs[i], typeAliases); funcFqnGenerics[fi] = "global::System.Func<" + string.Join(", ", sigParts) + ">"; delegateFqns[fi] = funcFqnGenerics[fi]; } @@ -633,12 +641,8 @@ private static string MethodId(string op, string fileTag, int line, int col) if (isRewritableReturn) { - if (typeAliases.TryGetValue(returnElemType!, out var retParam)) - returnRef = retParam; - else - // Composite return types like IGrouping need alias substitution - // (anonymous types have no nameable form in C# source). - returnRef = ResolveTypeFqn(returnElemType!, typeAliases); + var origReturnElem = ((INamedTypeSymbol)origMethod.ReturnType).TypeArguments[0]; + returnRef = ResolveTypeFqn(origReturnElem, typeAliases); } else { @@ -659,7 +663,7 @@ private static string MethodId(string op, string fileTag, int line, int col) } else { - var paramType = method.Parameters[i].Type; + var paramType = origMethod.Parameters[i].Type; var paramTypeFqn = ResolveTypeFqn(paramType, typeAliases); var paramName = method.Parameters[i].Name; interceptorParams.Add($"{paramTypeFqn} {paramName}"); diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/IndexRangeTests.cs b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/IndexRangeTests.cs index 5cc3be45..b573b841 100644 --- a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/IndexRangeTests.cs +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/IndexRangeTests.cs @@ -51,4 +51,24 @@ class C { return Verifier.Verify(result.GeneratedTrees[0].ToString()); } + + [TestMethod] + public void RangeSliceOnString_DoesNotEmitUnsupportedOperationDiagnostic() + { + var compilation = CreateCompilation( + """ + namespace Foo { + class C { + public string Label { get; set; } + + [Expressive] + public string FirstFive => Label[..5]; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + Assert.IsFalse( + result.Diagnostics.Any(d => d.Id == "EXP0008" && d.GetMessage().Contains("ImplicitIndexerReference"))); + } } diff --git a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ListPatternTests.cs b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ListPatternTests.cs index 128026ad..c6996263 100644 --- a/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ListPatternTests.cs +++ b/tests/ExpressiveSharp.Generator.Tests/ExpressiveGenerator/ListPatternTests.cs @@ -50,4 +50,30 @@ class C { return Verifier.Verify(result.GeneratedTrees[0].ToString()); } + + [TestMethod] + public void ListPattern_OnArray_DoesNotEmitUnsupportedOperationDiagnostic() + { + var compilation = CreateCompilation( + """ + namespace Foo { + class C { + public int[] Items { get; set; } + + [Expressive] + public string Shape() => Items switch + { + [] => "empty", + [_] => "one", + _ => "many", + }; + } + } + """); + var result = RunExpressiveGenerator(compilation); + + Assert.IsFalse( + result.Diagnostics.Any(d => d.Id == "EXP0008" && d.GetMessage().Contains("ListPattern")), + "Arrays have Length and indexer; list patterns on int[] should be supported."); + } } diff --git a/tests/ExpressiveSharp.Generator.Tests/PolyfillInterceptorGenerator/JoinTests.GroupJoin_AnonymousResultSelector_GeneratesGenericInterceptor.verified.txt b/tests/ExpressiveSharp.Generator.Tests/PolyfillInterceptorGenerator/JoinTests.GroupJoin_AnonymousResultSelector_GeneratesGenericInterceptor.verified.txt index 4f506702..02adaf02 100644 --- a/tests/ExpressiveSharp.Generator.Tests/PolyfillInterceptorGenerator/JoinTests.GroupJoin_AnonymousResultSelector_GeneratesGenericInterceptor.verified.txt +++ b/tests/ExpressiveSharp.Generator.Tests/PolyfillInterceptorGenerator/JoinTests.GroupJoin_AnonymousResultSelector_GeneratesGenericInterceptor.verified.txt @@ -11,7 +11,7 @@ namespace ExpressiveSharp.Generated.Interceptors global::System.Collections.Generic.IEnumerable inner, global::System.Func __func1, global::System.Func __func2, - global::System.Func, T3> __func3) + global::System.Func, T3> __func3) { // Source: o => o.CustomerId var i361d11c19a_p_o = global::System.Linq.Expressions.Expression.Parameter(typeof(T0), "o"); @@ -28,7 +28,7 @@ namespace ExpressiveSharp.Generated.Interceptors var i361d11c19c_expr_2 = global::System.Linq.Expressions.Expression.Call(global::System.Linq.Enumerable.First(global::System.Linq.Enumerable.Where(typeof(global::System.Linq.Enumerable).GetMethods(global::System.Reflection.BindingFlags.Public | global::System.Reflection.BindingFlags.NonPublic | global::System.Reflection.BindingFlags.Static), m => m.Name == "Count" && m.IsGenericMethodDefinition && m.GetGenericArguments().Length == 1 && m.GetParameters().Length == 1 && m.GetParameters()[0].ParameterType.IsGenericType && !m.GetParameters()[0].ParameterType.IsGenericParameter)).MakeGenericMethod(typeof(T1)), new global::System.Linq.Expressions.Expression[] { i361d11c19c_p_cs }); // cs.Count() var i361d11c19c_expr_3 = typeof(T3).GetConstructors()[0]; var i361d11c19c_expr_0 = global::System.Linq.Expressions.Expression.New(i361d11c19c_expr_3, new global::System.Linq.Expressions.Expression[] { i361d11c19c_expr_1, i361d11c19c_expr_2 }, new global::System.Reflection.MemberInfo[] { typeof(T3).GetProperty("CustomerId"), typeof(T3).GetProperty("Count") }); - var __lambda3 = global::System.Linq.Expressions.Expression.Lambda, T3>>(i361d11c19c_expr_0, i361d11c19c_p_o, i361d11c19c_p_cs); + var __lambda3 = global::System.Linq.Expressions.Expression.Lambda, T3>>(i361d11c19c_expr_0, i361d11c19c_p_o, i361d11c19c_p_cs); return (global::ExpressiveSharp.IExpressiveQueryable)(object) global::ExpressiveSharp.ExpressiveQueryableExtensions.AsExpressive( global::System.Linq.Queryable.GroupJoin( diff --git a/tests/ExpressiveSharp.Generator.Tests/PolyfillInterceptorGenerator/JoinTests.cs b/tests/ExpressiveSharp.Generator.Tests/PolyfillInterceptorGenerator/JoinTests.cs index d9f49d0a..af18ac2a 100644 --- a/tests/ExpressiveSharp.Generator.Tests/PolyfillInterceptorGenerator/JoinTests.cs +++ b/tests/ExpressiveSharp.Generator.Tests/PolyfillInterceptorGenerator/JoinTests.cs @@ -1,3 +1,5 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; using Microsoft.VisualStudio.TestTools.UnitTesting; using VerifyMSTest; using ExpressiveSharp.Generator.Tests.Infrastructure; @@ -136,4 +138,45 @@ public void Run(System.Linq.IQueryable orders, System.Collections.Generic return Verifier.Verify(result.GeneratedTrees[0].GetText().ToString()); } + + [TestMethod] + public void GroupJoin_AnonymousResultSelector_GeneratedInterceptorCompiles() + { + var source = + """ + using ExpressiveSharp; + + namespace TestNs + { + class Order { public int CustomerId { get; set; } } + class Customer { public int Id { get; set; } public string Name { get; set; } } + class TestClass + { + public void Run(System.Linq.IQueryable orders, System.Collections.Generic.IEnumerable customers) + { + orders.AsExpressive() + .GroupJoin(customers, + o => o.CustomerId, + c => c.Id, + (o, cs) => new { o.CustomerId, Count = cs.Count() }) + .ToList(); + } + } + } + """; + + var compilation = CreateCompilation(source); + var subject = new global::ExpressiveSharp.Generator.PolyfillInterceptorGenerator(); + GeneratorDriver driver = CSharpGeneratorDriver + .Create(subject) + .WithUpdatedParseOptions((CSharpParseOptions)compilation.SyntaxTrees.First().Options) + .RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out _); + + var errors = outputCompilation.GetDiagnostics() + .Where(d => d.Severity == DiagnosticSeverity.Error) + .Where(d => d.Id != "CS9137") + .ToList(); + + Assert.AreEqual(0, errors.Count, string.Join("\n", errors.Select(d => d.ToString()))); + } } diff --git a/tests/ExpressiveSharp.IntegrationTests/Tests/ExpansionEdgeCasesTests.cs b/tests/ExpressiveSharp.IntegrationTests/Tests/ExpansionEdgeCasesTests.cs new file mode 100644 index 00000000..237ff6a5 --- /dev/null +++ b/tests/ExpressiveSharp.IntegrationTests/Tests/ExpansionEdgeCasesTests.cs @@ -0,0 +1,99 @@ +using System.Linq.Expressions; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace ExpressiveSharp.IntegrationTests.Tests; + +[TestClass] +public class ExpansionEdgeCasesTests +{ + [TestMethod] + public void VirtualMethod_ExpansionPreservesPolymorphicDispatch() + { + var derived = new VirtualDispatchDerived { Id = 7, Name = "x" }; + var directCall = ((VirtualDispatchBase)derived).Describe(); + + Expression> expr = b => b.Describe(); + var expanded = (Expression>)expr.ExpandExpressives(); + var fromExpansion = expanded.Compile()(derived); + + Assert.AreEqual(directCall, fromExpansion); + } + + [TestMethod] + public void Polyfill_TypePatternSwitch_WithNullArm_BuildsExpression() + { + var expr = ExpressionPolyfill.Create((object o) => o switch + { + int i => i + 1, + string s => s.Length, + null => -1, + _ => 0, + }); + + var fn = expr.Compile(); + Assert.AreEqual(6, fn(5)); + Assert.AreEqual(5, fn("hello")); + Assert.AreEqual(-1, fn(null!)); + } + + [TestMethod] + public void Polyfill_LambdaCapturingTupleLocal_BuildsExpression() + { + var p = (X: 3, Y: 4); + var expr = ExpressionPolyfill.Create((int n) => n + p.X + p.Y); + Assert.AreEqual(8, expr.Compile()(1)); + } + + [TestMethod] + public void Polyfill_LambdaWithListPattern_IsIntercepted() + { + var expr = ExpressionPolyfill.Create((int[] a) => a switch + { + [] => "empty", + [_] => "one", + _ => "many", + }); + + var fn = expr.Compile(); + Assert.AreEqual("empty", fn(Array.Empty())); + Assert.AreEqual("one", fn(new[] { 1 })); + Assert.AreEqual("many", fn(new[] { 1, 2 })); + } + + [Ignore("ExpressiveReplacer has no recursion guard; expansion currently throws StackOverflowException, which is uncatchable and terminates the test runner.")] + [TestMethod] + public void RecursiveExpressiveMember_ExpandsWithoutCrashingTheProcess() + { + Expression> expr = t => t.Sum; + var expanded = expr.ExpandExpressives(); + Assert.IsNotNull(expanded); + } +} + +public class VirtualDispatchBase +{ + public int Id { get; set; } + + [Expressive] + public virtual string Describe() => $"base#{Id}"; +} + +public class VirtualDispatchDerived : VirtualDispatchBase +{ + public string? Name { get; set; } + + [Expressive] + public override string Describe() => $"derived#{Id}/{Name}"; +} + +public class RecursiveTree +{ + public int Value { get; set; } + public RecursiveTree? Left { get; set; } + public RecursiveTree? Right { get; set; } + + [Expressive] + public int Sum => Value + + (Left == null ? 0 : Left.Sum) + + (Right == null ? 0 : Right.Sum); +} diff --git a/tests/ExpressiveSharp.Tests/Extensions/ExpressionPolyfillTests.cs b/tests/ExpressiveSharp.Tests/Extensions/ExpressionPolyfillTests.cs new file mode 100644 index 00000000..0b317a05 --- /dev/null +++ b/tests/ExpressiveSharp.Tests/Extensions/ExpressionPolyfillTests.cs @@ -0,0 +1,12 @@ +namespace ExpressiveSharp.Tests.Extensions; + +[TestClass] +public class ExpressionPolyfillTests +{ + [TestMethod] + public void Create_NullTransformerArgument_ThrowsArgumentNullException() + { + Assert.ThrowsExactly(() => + ExpressionPolyfill.Create((int n) => n + 1, (IExpressionTreeTransformer)null!)); + } +}