forked from mjsabby/PInvokeCompiler
/
PInvokeMethodMetadataTraverser.cs
144 lines (121 loc) · 5.96 KB
/
PInvokeMethodMetadataTraverser.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
//-----------------------------------------------------------------------
// <copyright file="PInvokeMethodMetadataTraverser.cs" company="Microsoft">
// Copyright (c) Microsoft. All rights reserved.
// </copyright>
//-----------------------------------------------------------------------
namespace PInvokeCompiler
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Microsoft.Cci;
/// <summary>
/// Metadata traverser that collects the platform-invoke methods of every visited type, grouped by their
/// containing type, together with the distinct set of import modules referenced by those methods.
/// P/Invoke methods declared on the type passed to the constructor are ignored.
/// </summary>
/// <remarks>
/// Each collected method is validated as it is visited: a P/Invoke whose return type or parameters
/// cannot be marshalled by this compiler aborts the traversal with a <see cref="NotSupportedException"/>.
/// </remarks>
internal sealed class PInvokeMethodMetadataTraverser : MetadataTraverser, IPInvokeMethodsProvider
{
    /// <summary>Collected P/Invoke methods, keyed by containing type, in the order they were visited.</summary>
    private readonly Dictionary<ITypeDefinition, List<IMethodDefinition>> typeDefinitionTable = new Dictionary<ITypeDefinition, List<IMethodDefinition>>();

    /// <summary>Import modules referenced by the collected methods, keyed by containing type.</summary>
    private readonly Dictionary<ITypeDefinition, HashSet<IModuleReference>> moduleRefsTable = new Dictionary<ITypeDefinition, HashSet<IModuleReference>>();

    /// <summary>Type whose P/Invoke methods are excluded from collection.</summary>
    private readonly ITypeReference skipTypeReference;

    /// <summary>
    /// Initializes a new instance of the <see cref="PInvokeMethodMetadataTraverser"/> class.
    /// </summary>
    /// <param name="skipTypeReference">
    /// Type whose P/Invoke methods must not be collected; compared against each method's containing
    /// type with <c>TypeHelper.TypesAreEquivalent</c>.
    /// </param>
    public PInvokeMethodMetadataTraverser(ITypeReference skipTypeReference)
    {
        this.skipTypeReference = skipTypeReference;
    }

    /// <summary>
    /// Records <paramref name="methodDefinition"/> when it is a platform-invoke method not declared on the
    /// skipped type, after verifying that its signature can be marshalled.
    /// </summary>
    /// <param name="methodDefinition">The method being visited.</param>
    /// <exception cref="NotSupportedException">
    /// The method's return type, or one of its parameters, is not supported for marshalling.
    /// </exception>
    /// <remarks>
    /// <c>base.TraverseChildren</c> is not called, so this traverser does not descend into the
    /// method's own children.
    /// </remarks>
    public override void TraverseChildren(IMethodDefinition methodDefinition)
    {
        if (!methodDefinition.IsPlatformInvoke)
        {
            return;
        }

        var typeDefinition = methodDefinition.ContainingTypeDefinition;
        if (TypeHelper.TypesAreEquivalent(typeDefinition, this.skipTypeReference))
        {
            return;
        }

        if (!IsReturnTypeSupported(methodDefinition))
        {
            throw new NotSupportedException($"Return type {methodDefinition.Type} of {methodDefinition} is not supported for marshalling");
        }

        // Identify the first offending parameter (same evaluation order as Enumerable.All) so the
        // error points at the exact declaration that needs to change.
        var unsupportedParameter = methodDefinition.Parameters.FirstOrDefault(parameter => !IsParameterSupported(parameter));
        if (unsupportedParameter != null)
        {
            throw new NotSupportedException($"Parameter '{unsupportedParameter.Name.Value}' of type {unsupportedParameter.Type} in {methodDefinition} is not supported for marshalling");
        }

        List<IMethodDefinition> methodDefinitions;
        if (!this.typeDefinitionTable.TryGetValue(typeDefinition, out methodDefinitions))
        {
            methodDefinitions = new List<IMethodDefinition>();
            this.typeDefinitionTable.Add(typeDefinition, methodDefinitions);
        }

        HashSet<IModuleReference> moduleRefs;
        if (!this.moduleRefsTable.TryGetValue(typeDefinition, out moduleRefs))
        {
            moduleRefs = new HashSet<IModuleReference>();
            this.moduleRefsTable.Add(typeDefinition, moduleRefs);
        }

        methodDefinitions.Add(methodDefinition);
        moduleRefs.Add(methodDefinition.PlatformInvokeData.ImportModule);
    }

    /// <summary>Returns the P/Invoke methods collected for <paramref name="typeDefinition"/>.</summary>
    /// <param name="typeDefinition">The containing type to look up.</param>
    /// <returns>
    /// The collected methods in visit order, or an empty sequence when none were collected for the type.
    /// </returns>
    public IEnumerable<IMethodDefinition> RetrieveMethodDefinitions(ITypeDefinition typeDefinition)
    {
        List<IMethodDefinition> methods;
        return this.typeDefinitionTable.TryGetValue(typeDefinition, out methods) ? methods : Enumerable.Empty<IMethodDefinition>();
    }

    /// <summary>Returns the import modules referenced by the P/Invoke methods of <paramref name="typeDefinition"/>.</summary>
    /// <param name="typeDefinition">The containing type to look up.</param>
    /// <returns>
    /// The distinct module references, or an empty sequence when no methods were collected for the type.
    /// </returns>
    public IEnumerable<IModuleReference> RetrieveModuleRefs(ITypeDefinition typeDefinition)
    {
        HashSet<IModuleReference> moduleRefs;
        return this.moduleRefsTable.TryGetValue(typeDefinition, out moduleRefs) ? moduleRefs : Enumerable.Empty<IModuleReference>();
    }

    /// <summary>
    /// Determines whether the return value of <paramref name="methodDefinition"/> can be marshalled.
    /// </summary>
    /// <param name="methodDefinition">The P/Invoke method whose return type is checked.</param>
    /// <returns>
    /// When the return value carries explicit marshalling information, <c>true</c> only for a string
    /// marshalled as <see cref="UnmanagedType.LPWStr"/> or <see cref="UnmanagedType.LPStr"/>; otherwise
    /// <c>true</c> for boolean, blittable, delegate and string return types.
    /// </returns>
    private static bool IsReturnTypeSupported(IMethodDefinition methodDefinition)
    {
        var returnType = methodDefinition.Type;

        if (methodDefinition.ReturnValueIsMarshalledExplicitly)
        {
            // Explicit return marshalling is only honoured for strings.
            var unmanagedType = methodDefinition.ReturnValueMarshallingInformation.UnmanagedType;
            return returnType.IsString() && (unmanagedType == UnmanagedType.LPWStr || unmanagedType == UnmanagedType.LPStr);
        }

        return returnType.TypeCode == PrimitiveTypeCode.Boolean
            || returnType.IsBlittable()
            || returnType.IsDelegate()
            || returnType.IsString();
    }

    /// <summary>
    /// Determines whether <paramref name="parameterDefinition"/> can be marshalled.
    /// </summary>
    /// <param name="parameterDefinition">The P/Invoke parameter to check.</param>
    /// <returns><c>true</c> when the parameter's type and marshalling information are supported.</returns>
    private static bool IsParameterSupported(IParameterDefinition parameterDefinition)
    {
        var parameterType = parameterDefinition.Type;

        // Explicit marshalling: LPWStr/LPStr and LPArray are decided here outright; any other
        // unmanaged type falls through to the default rules below.
        if (parameterDefinition.IsMarshalledExplicitly)
        {
            var unmanagedType = parameterDefinition.MarshallingInformation.UnmanagedType;
            switch (unmanagedType)
            {
                case UnmanagedType.LPWStr:
                case UnmanagedType.LPStr:
                    return parameterType.IsString() || parameterType.IsStringArray();

                case UnmanagedType.LPArray:
                    if (parameterType.IsBlittable())
                    {
                        return true;
                    }

                    // string[] passed as LPArray is only supported with string element marshalling.
                    if (parameterType.IsStringArray())
                    {
                        var elementType = parameterDefinition.MarshallingInformation.ElementType;
                        return elementType == UnmanagedType.LPStr || elementType == UnmanagedType.LPWStr;
                    }

                    return false;
            }
        }

        // Booleans, blittable types, delegates and strings; delegates and strings receive
        // special marshalling elsewhere in the compiler.
        if (parameterType.TypeCode == PrimitiveTypeCode.Boolean || parameterType.IsBlittable() || parameterType.IsDelegate() || parameterType.IsString())
        {
            return true;
        }

        // string[] is also supported since it is so common, by converting it to IntPtr[] in a try/finally.
        // TODO: Support ICustomMarshaler
        return parameterType.IsStringArray();
    }
}
}