Skip to content

Commit 86e9f48

Browse files
authored
Update SslOverTdsStream (dotnet#541)
1 parent f2cbaf8 commit 86e9f48

File tree

5 files changed

+574
-218
lines changed

5 files changed

+574
-218
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@
261261
<Compile Include="Microsoft\Data\SqlClient\SqlDiagnosticListener.NetStandard.cs" />
262262
<Compile Include="Microsoft\Data\SqlClient\SqlDelegatedTransaction.NetStandard.cs" />
263263
<Compile Include="Microsoft\Data\SqlClient\TdsParser.NetStandard.cs" />
264+
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetStandard.cs" />
264265
</ItemGroup>
265266
<ItemGroup Condition="'$(OSGroup)' != 'AnyOS' AND '$(TargetFramework)' != 'netstandard2.0'">
266267
<Compile Include="..\..\src\Microsoft\Data\SqlClient\AlwaysEncryptedAttestationException.cs">
@@ -289,6 +290,7 @@
289290
<Compile Include="Microsoft\Data\SqlClient\SqlDiagnosticListener.NetCoreApp.cs" />
290291
<Compile Include="Microsoft\Data\SqlClient\SqlDelegatedTransaction.NetCoreApp.cs" />
291292
<Compile Include="Microsoft\Data\SqlClient\TdsParser.NetCoreApp.cs" />
293+
<Compile Include="Microsoft\Data\SqlClient\SNI\SslOverTdsStream.NetCoreApp.cs" />
292294
</ItemGroup>
293295
<ItemGroup Condition="'$(OSGroup)' != 'AnyOS' AND '$(TargetGroup)' == 'netcoreapp' AND '$(BuildSimulator)' == 'true'">
294296
<Compile Include="Microsoft\Data\SqlClient\SimulatorEnclaveProvider.NetCoreApp.cs" />
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Buffers;
7+
using System.Threading;
8+
using System.Threading.Tasks;
9+
10+
namespace Microsoft.Data.SqlClient.SNI
11+
{
12+
internal sealed partial class SslOverTdsStream
13+
{
14+
public override int Read(byte[] buffer, int offset, int count)
15+
{
16+
return Read(buffer.AsSpan(offset, count));
17+
}
18+
19+
public override void Write(byte[] buffer, int offset, int count)
20+
{
21+
Write(buffer.AsSpan(offset, count));
22+
}
23+
24+
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
25+
{
26+
return ReadAsync(new Memory<byte>(buffer, offset, count), cancellationToken).AsTask();
27+
}
28+
29+
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
30+
{
31+
return WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
32+
}
33+
34+
public override int Read(Span<byte> buffer)
35+
{
36+
if (!_encapsulate)
37+
{
38+
return _stream.Read(buffer);
39+
}
40+
41+
using (SNIEventScope.Create("<sc.SNI.SslOverTdsStream.Read |SNI|INFO|SCOPE> reading encapsulated bytes"))
42+
{
43+
if (_packetBytes > 0)
44+
{
45+
// there are queued bytes from a previous packet available
46+
// work out how many of the remaining bytes we can consume
47+
int wantedCount = Math.Min(buffer.Length, _packetBytes);
48+
int readCount = _stream.Read(buffer.Slice(0, wantedCount));
49+
if (readCount == 0)
50+
{
51+
// 0 means the connection was closed, tell the caller
52+
return 0;
53+
}
54+
_packetBytes -= readCount;
55+
return readCount;
56+
}
57+
else
58+
{
59+
Span<byte> headerBytes = stackalloc byte[TdsEnums.HEADER_LEN];
60+
61+
// fetch the packet header to determine how long the packet is
62+
int headerBytesRead = 0;
63+
do
64+
{
65+
int headerBytesReadIteration = _stream.Read(headerBytes.Slice(headerBytesRead, TdsEnums.HEADER_LEN - headerBytesRead));
66+
if (headerBytesReadIteration == 0)
67+
{
68+
// 0 means the connection was closed, tell the caller
69+
return 0;
70+
}
71+
headerBytesRead += headerBytesReadIteration;
72+
} while (headerBytesRead < TdsEnums.HEADER_LEN);
73+
74+
// read the packet data size from the header and store it in case it is needed for a subsequent call
75+
_packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN;
76+
77+
// read as much from the packet as the caller can accept
78+
int packetBytesRead = _stream.Read(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes)));
79+
_packetBytes -= packetBytesRead;
80+
return packetBytesRead;
81+
}
82+
}
83+
}
84+
85+
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
86+
{
87+
if (!_encapsulate)
88+
{
89+
int read;
90+
{
91+
ValueTask<int> readValueTask = _stream.ReadAsync(buffer, cancellationToken);
92+
if (readValueTask.IsCompletedSuccessfully)
93+
{
94+
read = readValueTask.Result;
95+
}
96+
else
97+
{
98+
read = await readValueTask.ConfigureAwait(false);
99+
}
100+
}
101+
return read;
102+
}
103+
using (SNIEventScope.Create("<sc.SNI.SslOverTdsStream.ReadAsync |SNI|INFO|SCOPE> reading encapsulated bytes"))
104+
{
105+
if (_packetBytes > 0)
106+
{
107+
// there are queued bytes from a previous packet available
108+
// work out how many of the remaining bytes we can consume
109+
int wantedCount = Math.Min(buffer.Length, _packetBytes);
110+
111+
int readCount;
112+
{
113+
ValueTask<int> remainderReadValueTask = _stream.ReadAsync(buffer.Slice(0, wantedCount), cancellationToken);
114+
if (remainderReadValueTask.IsCompletedSuccessfully)
115+
{
116+
readCount = remainderReadValueTask.Result;
117+
}
118+
else
119+
{
120+
readCount = await remainderReadValueTask.ConfigureAwait(false);
121+
}
122+
}
123+
if (readCount == 0)
124+
{
125+
// 0 means the connection was closed, tell the caller
126+
return 0;
127+
}
128+
_packetBytes -= readCount;
129+
return readCount;
130+
}
131+
else
132+
{
133+
byte[] headerBytes = ArrayPool<byte>.Shared.Rent(TdsEnums.HEADER_LEN);
134+
135+
// fetch the packet header to determine how long the packet is
136+
int headerBytesRead = 0;
137+
do
138+
{
139+
int headerBytesReadIteration;
140+
{
141+
ValueTask<int> headerReadValueTask = _stream.ReadAsync(headerBytes.AsMemory(headerBytesRead, (TdsEnums.HEADER_LEN - headerBytesRead)), cancellationToken);
142+
if (headerReadValueTask.IsCompletedSuccessfully)
143+
{
144+
headerBytesReadIteration = headerReadValueTask.Result;
145+
}
146+
else
147+
{
148+
headerBytesReadIteration = await headerReadValueTask.ConfigureAwait(false);
149+
}
150+
}
151+
if (headerBytesReadIteration == 0)
152+
{
153+
// 0 means the connection was closed, cleanup the rented array and then tell the caller
154+
ArrayPool<byte>.Shared.Return(headerBytes, clearArray: true);
155+
return 0;
156+
}
157+
headerBytesRead += headerBytesReadIteration;
158+
} while (headerBytesRead < TdsEnums.HEADER_LEN);
159+
160+
// read the packet data size from the header and store it in case it is needed for a subsequent call
161+
_packetBytes = ((headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET] << 8) | headerBytes[TdsEnums.HEADER_LEN_FIELD_OFFSET + 1]) - TdsEnums.HEADER_LEN;
162+
163+
ArrayPool<byte>.Shared.Return(headerBytes, clearArray: true);
164+
165+
// read as much from the packet as the caller can accept
166+
int packetBytesRead;
167+
{
168+
ValueTask<int> packetReadValueTask = _stream.ReadAsync(buffer.Slice(0, Math.Min(buffer.Length, _packetBytes)), cancellationToken);
169+
if (packetReadValueTask.IsCompletedSuccessfully)
170+
{
171+
packetBytesRead = packetReadValueTask.Result;
172+
}
173+
else
174+
{
175+
packetBytesRead = await packetReadValueTask.ConfigureAwait(false);
176+
}
177+
}
178+
_packetBytes -= packetBytesRead;
179+
return packetBytesRead;
180+
}
181+
}
182+
}
183+
184+
public override void Write(ReadOnlySpan<byte> buffer)
185+
{
186+
// During the SSL negotiation phase, SSL is tunnelled over TDS packet type 0x12. After
187+
// negotiation, the underlying socket only sees SSL frames.
188+
if (!_encapsulate)
189+
{
190+
_stream.Write(buffer);
191+
_stream.Flush();
192+
return;
193+
}
194+
195+
using (SNIEventScope.Create("<sc.SNI.SslOverTdsStream.Write |SNI|INFO|SCOPE> writing encapsulated bytes"))
196+
{
197+
ReadOnlySpan<byte> remaining = buffer;
198+
byte[] packetBuffer = null;
199+
try
200+
{
201+
while (remaining.Length > 0)
202+
{
203+
int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length);
204+
int packetLength = TdsEnums.HEADER_LEN + dataLength;
205+
206+
if (packetBuffer == null)
207+
{
208+
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength);
209+
}
210+
else if (packetBuffer.Length < packetLength)
211+
{
212+
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true);
213+
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength);
214+
}
215+
216+
SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength);
217+
218+
Span<byte> data = packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength);
219+
remaining.Slice(0, dataLength).CopyTo(data);
220+
221+
_stream.Write(packetBuffer.AsSpan(0, packetLength));
222+
_stream.Flush();
223+
224+
remaining = remaining.Slice(dataLength);
225+
}
226+
}
227+
finally
228+
{
229+
if (packetBuffer != null)
230+
{
231+
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true);
232+
}
233+
}
234+
}
235+
}
236+
237+
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
238+
{
239+
if (!_encapsulate)
240+
{
241+
{
242+
ValueTask valueTask = _stream.WriteAsync(buffer, cancellationToken);
243+
if (!valueTask.IsCompletedSuccessfully)
244+
{
245+
await valueTask.ConfigureAwait(false);
246+
}
247+
}
248+
Task flushTask = _stream.FlushAsync();
249+
if (flushTask.IsCompletedSuccessfully)
250+
{
251+
await flushTask.ConfigureAwait(false);
252+
}
253+
return;
254+
}
255+
256+
using (SNIEventScope.Create("<sc.SNI.SslOverTdsStream.WriteAsync |SNI|INFO|SCOPE> writing encapsulated bytes"))
257+
{
258+
ReadOnlyMemory<byte> remaining = buffer;
259+
byte[] packetBuffer = null;
260+
try
261+
{
262+
while (remaining.Length > 0)
263+
{
264+
int dataLength = Math.Min(PACKET_SIZE_WITHOUT_HEADER, remaining.Length);
265+
int packetLength = TdsEnums.HEADER_LEN + dataLength;
266+
267+
if (packetBuffer == null)
268+
{
269+
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength);
270+
}
271+
else if (packetBuffer.Length < packetLength)
272+
{
273+
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true);
274+
packetBuffer = ArrayPool<byte>.Shared.Rent(packetLength);
275+
}
276+
277+
SetupPreLoginPacketHeader(packetBuffer, dataLength, remaining.Length - dataLength);
278+
279+
remaining.Span.Slice(0, dataLength).CopyTo(packetBuffer.AsSpan(TdsEnums.HEADER_LEN, dataLength));
280+
281+
{
282+
ValueTask packetWriteValueTask = _stream.WriteAsync(new ReadOnlyMemory<byte>(packetBuffer, 0, packetLength), cancellationToken);
283+
if (!packetWriteValueTask.IsCompletedSuccessfully)
284+
{
285+
await packetWriteValueTask.ConfigureAwait(false);
286+
}
287+
}
288+
289+
await _stream.FlushAsync().ConfigureAwait(false);
290+
291+
292+
remaining = remaining.Slice(dataLength);
293+
}
294+
}
295+
finally
296+
{
297+
if (packetBuffer != null)
298+
{
299+
ArrayPool<byte>.Shared.Return(packetBuffer, clearArray: true);
300+
}
301+
}
302+
}
303+
}
304+
}
305+
}

0 commit comments

Comments
 (0)