using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.Versioning;
using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Fleck.Handlers;
using Fleck.Helpers;
// Assembly-level metadata for the Fleck library: compiler/runtime compatibility settings,
// the declared target framework, and product/version information (version 1.2.0).
// NOTE(review): TargetFramework declares .NET Framework 4.5, but code later in this file
// uses index-from-end syntax (`array[^1]`), which needs System.Index — confirm which
// framework this source is actually built against.
[assembly: CompilationRelaxations(8)]
[assembly: RuntimeCompatibility(WrapNonExceptionThrows = true)]
[assembly: Debuggable(DebuggableAttribute.DebuggingModes.IgnoreSymbolStoreSequencePoints)]
[assembly: TargetFramework(".NETFramework,Version=v4.5", FrameworkDisplayName = ".NET Framework 4.5")]
[assembly: AssemblyCompany("statenjason")]
[assembly: AssemblyConfiguration("Release")]
[assembly: AssemblyCopyright("Copyright Jason Staten 2010-2018. All rights reserved.")]
[assembly: AssemblyDescription("C# WebSocket Implementation")]
[assembly: AssemblyFileVersion("1.2.0.0")]
[assembly: AssemblyInformationalVersion("1.2.0")]
[assembly: AssemblyProduct("Fleck")]
[assembly: AssemblyTitle("Fleck")]
[assembly: AssemblyVersion("1.2.0.0")]
namespace Fleck
{
/// <summary>
/// Signals that data was sent on a connection that is closing or already closed.
/// </summary>
public class ConnectionNotAvailableException : Exception
{
    /// <summary>Creates the exception with the default message.</summary>
    public ConnectionNotAvailableException() : base()
    {
    }

    /// <summary>Creates the exception with a descriptive message.</summary>
    /// <param name="message">Text describing why the connection is unavailable.</param>
    public ConnectionNotAvailableException(string message) : base(message)
    {
    }

    /// <summary>Creates the exception wrapping the exception that caused it.</summary>
    /// <param name="message">Text describing why the connection is unavailable.</param>
    /// <param name="innerException">The underlying cause.</param>
    public ConnectionNotAvailableException(string message, Exception innerException) : base(message, innerException)
    {
    }
}
/// <summary>
/// Severity of a log entry. Members are ordered from least to most severe, so levels
/// can be filtered with a plain comparison.
/// </summary>
public enum LogLevel
{
    /// <summary>Verbose diagnostic output.</summary>
    Debug = 0,

    /// <summary>Normal operational messages.</summary>
    Info = 1,

    /// <summary>Something unexpected that the library recovered from.</summary>
    Warn = 2,

    /// <summary>A failure.</summary>
    Error = 3
}
/// <summary>
/// Minimal static logger. Every helper forwards to <see cref="LogAction"/>, which callers may
/// replace. The default action writes entries whose level is at least <see cref="Level"/>
/// to the console.
/// </summary>
public class FleckLog
{
    /// <summary>Minimum level printed by the default <see cref="LogAction"/>.</summary>
    public static LogLevel Level = LogLevel.Info;

    /// <summary>Sink for every log call: (level, message, optional exception).</summary>
    public static Action<LogLevel, string, Exception> LogAction = (level, message, ex) =>
    {
        // Anything below the configured threshold is dropped.
        if (level < Level)
        {
            return;
        }
        Console.WriteLine("{0} [{1}] {2} {3}", DateTime.Now, level, message, ex);
    };

    /// <summary>Logs <paramref name="message"/> at Debug level.</summary>
    public static void Debug(string message, Exception ex = null) => LogAction(LogLevel.Debug, message, ex);

    /// <summary>Logs <paramref name="message"/> at Info level.</summary>
    public static void Info(string message, Exception ex = null) => LogAction(LogLevel.Info, message, ex);

    /// <summary>Logs <paramref name="message"/> at Warn level.</summary>
    public static void Warn(string message, Exception ex = null) => LogAction(LogLevel.Warn, message, ex);

    /// <summary>Logs <paramref name="message"/> at Error level.</summary>
    public static void Error(string message, Exception ex = null) => LogAction(LogLevel.Error, message, ex);
}
/// <summary>
/// WebSocket frame types. The numeric values match the RFC 6455 opcodes:
/// 0x0-0x2 are data frames and 0x8-0xA are control frames.
/// </summary>
public enum FrameType : byte
{
    Continuation = 0x0,
    Text = 0x1,
    Binary = 0x2,
    Close = 0x8,
    Ping = 0x9,
    Pong = 0xA
}
/// <summary>
/// Selects the protocol handler for a parsed handshake request.
/// </summary>
public class HandlerFactory
{
    /// <summary>
    /// Creates the handler matching the protocol version reported by <see cref="GetVersion"/>.
    /// </summary>
    /// <param name="request">The parsed handshake request.</param>
    /// <param name="onMessage">Receives decoded text messages.</param>
    /// <param name="onClose">Close callback; only passed to the Hybi13 handler.</param>
    /// <param name="onBinary">Receives binary payloads; only passed to the Hybi13 handler.</param>
    /// <param name="onPing">Receives ping payloads; only passed to the Hybi13 handler.</param>
    /// <param name="onPong">Receives pong payloads; only passed to the Hybi13 handler.</param>
    /// <returns>A handler for versions "76", "7", "8", "13" or for a Flash policy-file request.</returns>
    /// <exception cref="WebSocketException">Status 1003 for any other version, including "75".</exception>
    public static IHandler BuildHandler(WebSocketHttpRequest request, Action<string> onMessage, Action onClose, Action<byte[]> onBinary, Action<byte[]> onPing, Action<byte[]> onPong)
    {
        switch (GetVersion(request))
        {
            case "76":
                // Draft76Handler.Create accepts only the text-message callback.
                return Draft76Handler.Create(request, onMessage);
            case "7":
            case "8":
            case "13":
                return Hybi13Handler.Create(request, onMessage, onClose, onBinary, onPing, onPong);
            case "policy-file-request":
                return FlashSocketPolicyRequestHandler.Create(request);
            default:
                throw new WebSocketException(1003);
        }
    }

    /// <summary>
    /// Determines the protocol version of a request. It checks, in order: the
    /// Sec-WebSocket-Version header, the Sec-WebSocket-Draft header, "76" when
    /// Sec-WebSocket-Key1 is present, "policy-file-request" when the body mentions it,
    /// and finally "75".
    /// </summary>
    /// <param name="request">The parsed handshake request.</param>
    /// <returns>The version string used by <see cref="BuildHandler"/>.</returns>
    public static string GetVersion(WebSocketHttpRequest request)
    {
        if (request.Headers.TryGetValue("Sec-WebSocket-Version", out var value))
        {
            return value;
        }
        if (request.Headers.TryGetValue("Sec-WebSocket-Draft", out value))
        {
            return value;
        }
        if (request.Headers.ContainsKey("Sec-WebSocket-Key1"))
        {
            return "76";
        }
        // Ordinal, case-insensitive search. ToLower() is culture-sensitive: under the Turkish
        // culture, for example, "I" lowers to a dotless "ı" and the match would fail.
        if (request.Body != null && request.Body.IndexOf("policy-file-request", StringComparison.OrdinalIgnoreCase) >= 0)
        {
            return "policy-file-request";
        }
        return "75";
    }
}
/// <summary>
/// Protocol-specific handshake and framing logic used by a WebSocket connection.
/// </summary>
public interface IHandler
{
    /// <summary>Builds the handshake response bytes, optionally naming the negotiated sub-protocol.</summary>
    /// <param name="subProtocol">The negotiated sub-protocol, or null when none applies.</param>
    byte[] CreateHandshake(string subProtocol = null);

    /// <summary>Feeds newly received raw bytes into the handler.</summary>
    /// <param name="data">Bytes just read from the socket.</param>
    void Receive(IEnumerable<byte> data);

    /// <summary>Encodes a text message as an outgoing frame.</summary>
    byte[] FrameText(string text);

    /// <summary>Encodes a binary message as an outgoing frame.</summary>
    byte[] FrameBinary(byte[] bytes);

    /// <summary>Encodes a ping frame carrying <paramref name="bytes"/>.</summary>
    byte[] FramePing(byte[] bytes);

    /// <summary>Encodes a pong frame carrying <paramref name="bytes"/>.</summary>
    byte[] FramePong(byte[] bytes);

    /// <summary>
    /// Encodes a close frame for status <paramref name="code"/>. An empty array means no
    /// close frame is sent; the connection then closes the socket directly.
    /// </summary>
    byte[] FrameClose(int code);
}
/// <summary>
/// Abstraction over a TCP socket and its data stream. Asynchronous operations report
/// results through callbacks and also return the underlying task.
/// </summary>
public interface ISocket
{
    /// <summary>Whether the underlying socket is connected.</summary>
    bool Connected { get; }

    /// <summary>Remote IP address as text, or null when unavailable.</summary>
    string RemoteIpAddress { get; }

    /// <summary>Remote port, or -1 when unavailable.</summary>
    int RemotePort { get; }

    /// <summary>Stream used to read and write connection data.</summary>
    Stream Stream { get; }

    /// <summary>Gets or sets the socket's Nagle-algorithm setting (NoDelay).</summary>
    bool NoDelay { get; set; }

    /// <summary>Local endpoint the socket is bound to.</summary>
    EndPoint LocalEndPoint { get; }

    /// <summary>Accepts one incoming connection; <paramref name="callback"/> receives it, <paramref name="error"/> any failure.</summary>
    Task<ISocket> Accept(Action<ISocket> callback, Action<Exception> error);

    /// <summary>Writes <paramref name="buffer"/>; <paramref name="callback"/> runs on success, <paramref name="error"/> on failure.</summary>
    Task Send(byte[] buffer, Action callback, Action<Exception> error);

    /// <summary>Reads into <paramref name="buffer"/> starting at <paramref name="offset"/>; <paramref name="callback"/> receives the byte count.</summary>
    Task<int> Receive(byte[] buffer, Action<int> callback, Action<Exception> error, int offset = 0);

    /// <summary>Performs the server side of a TLS handshake with <paramref name="certificate"/>.</summary>
    Task Authenticate(X509Certificate2 certificate, SslProtocols enabledSslProtocols, Action callback, Action<Exception> error);

    /// <summary>Releases the socket and its stream.</summary>
    void Dispose();

    /// <summary>Closes the socket and its stream.</summary>
    void Close();

    /// <summary>Binds the socket to <paramref name="ipLocal"/>.</summary>
    void Bind(EndPoint ipLocal);

    /// <summary>Starts listening with the given pending-connection backlog.</summary>
    void Listen(int backlog);
}
/// <summary>
/// A single WebSocket client connection as seen by application code. It exposes
/// assignable event callbacks plus methods to send data and close the connection.
/// </summary>
public interface IWebSocketConnection
{
    /// <summary>Invoked when the connection opens.</summary>
    Action OnOpen { get; set; }

    /// <summary>Invoked when the connection closes.</summary>
    Action OnClose { get; set; }

    /// <summary>Invoked with each received text message.</summary>
    Action<string> OnMessage { get; set; }

    /// <summary>Invoked with each received binary message.</summary>
    Action<byte[]> OnBinary { get; set; }

    /// <summary>Invoked with the payload of each received ping.</summary>
    Action<byte[]> OnPing { get; set; }

    /// <summary>Invoked with the payload of each received pong.</summary>
    Action<byte[]> OnPong { get; set; }

    /// <summary>Invoked when an error occurs on the connection.</summary>
    Action<Exception> OnError { get; set; }

    /// <summary>Details of the handshake request and remote endpoint.</summary>
    IWebSocketConnectionInfo ConnectionInfo { get; }

    /// <summary>Whether the connection can currently send data.</summary>
    bool IsAvailable { get; }

    /// <summary>Sends a text message.</summary>
    Task Send(string message);

    /// <summary>Sends a binary message.</summary>
    Task Send(byte[] message);

    /// <summary>Sends a ping carrying <paramref name="message"/>.</summary>
    Task SendPing(byte[] message);

    /// <summary>Sends a pong carrying <paramref name="message"/>.</summary>
    Task SendPong(byte[] message);

    /// <summary>Closes the connection using the implementation's default status code.</summary>
    void Close();

    /// <summary>Closes the connection with the given status code.</summary>
    void Close(int code);
}
/// <summary>
/// Read-only information about a connection, taken from its handshake request and
/// remote endpoint.
/// </summary>
public interface IWebSocketConnectionInfo
{
    /// <summary>Sub-protocols requested by the client (raw header value).</summary>
    string SubProtocol { get; }

    /// <summary>Origin of the request.</summary>
    string Origin { get; }

    /// <summary>Host header of the request.</summary>
    string Host { get; }

    /// <summary>Request path.</summary>
    string Path { get; }

    /// <summary>Client IP address as text.</summary>
    string ClientIpAddress { get; }

    /// <summary>Client TCP port.</summary>
    int ClientPort { get; }

    /// <summary>Cookies sent by the client, keyed by name.</summary>
    IDictionary<string, string> Cookies { get; }

    /// <summary>All request headers, keyed by name.</summary>
    IDictionary<string, string> Headers { get; }

    /// <summary>Identifier of this connection.</summary>
    Guid Id { get; }

    /// <summary>Sub-protocol chosen during negotiation, or null.</summary>
    string NegotiatedSubProtocol { get; }
}
/// <summary>
/// A WebSocket server that starts listening on demand and stops when disposed.
/// </summary>
public interface IWebSocketServer : IDisposable
{
    /// <summary>
    /// Starts accepting clients. <paramref name="config"/> is given each new connection so
    /// the caller can attach its event handlers.
    /// </summary>
    void Start(Action<IWebSocketConnection> config);
}
/// <summary>
/// Conversions between integers and big-endian (network byte order) byte arrays.
/// </summary>
public static class IntExtensions
{
    /// <summary>
    /// Serializes <paramref name="source"/> as big-endian bytes. The width depends on
    /// <typeparamref name="T"/>: 2 bytes for <see cref="ushort"/>, 8 for <see cref="ulong"/>
    /// and 4 for <see cref="int"/>.
    /// </summary>
    /// <typeparam name="T">Target width: <see cref="ushort"/>, <see cref="ulong"/> or <see cref="int"/>.</typeparam>
    /// <param name="source">Value to serialize. A negative value is sign-extended for <see cref="ulong"/>.</param>
    /// <returns>A new array holding the value most-significant byte first.</returns>
    /// <exception cref="InvalidCastException"><typeparamref name="T"/> is any other type.</exception>
    public static byte[] ToBigEndianBytes<T>(this int source)
    {
        Type target = typeof(T);
        byte[] bytes;
        if (target == typeof(ushort))
        {
            bytes = BitConverter.GetBytes((ushort)source);
        }
        else if (target == typeof(ulong))
        {
            bytes = BitConverter.GetBytes((ulong)source);
        }
        else if (target == typeof(int))
        {
            bytes = BitConverter.GetBytes(source);
        }
        else
        {
            throw new InvalidCastException("Cannot be cast to T");
        }
        // BitConverter emits native order; flip it on little-endian hosts.
        if (BitConverter.IsLittleEndian)
        {
            Array.Reverse(bytes);
        }
        return bytes;
    }

    /// <summary>
    /// Reads a 2- or 8-byte big-endian unsigned integer and returns it as an <see cref="int"/>.
    /// Despite the method name, the input is expected in big-endian order.
    /// </summary>
    /// <param name="source">2 or 8 bytes, most-significant first. The array is not modified.</param>
    /// <returns>The decoded value. 8-byte values above <see cref="int.MaxValue"/> are truncated by the cast.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="source"/> is neither 2 nor 8 bytes long.</exception>
    public static int ToLittleEndianInt(this byte[] source)
    {
        if (source == null)
        {
            throw new ArgumentNullException(nameof(source));
        }
        if (source.Length != 2 && source.Length != 8)
        {
            throw new ArgumentException("Unsupported Size");
        }
        // Work on a copy. Reversing in place would silently corrupt the caller's array.
        byte[] ordered = (byte[])source.Clone();
        if (BitConverter.IsLittleEndian)
        {
            Array.Reverse(ordered);
        }
        if (ordered.Length == 2)
        {
            return BitConverter.ToUInt16(ordered, 0);
        }
        return (int)BitConverter.ToUInt64(ordered, 0);
    }
}
/// <summary>
/// Stream wrapper that serializes asynchronous writes. While one BeginWrite is in flight
/// on the inner stream, later BeginWrite calls are queued and started one at a time as
/// earlier writes complete. Reads, seeking and flushing pass straight through.
/// Synchronous <see cref="Write"/> is not supported.
/// </summary>
/// <remarks>
/// NOTE(review): SocketWrapper.Authenticate wraps its SslStream in this class, presumably
/// because SslStream does not accept overlapping writes. Confirm before relying on it.
/// </remarks>
public class QueuedStream : Stream
{
    /// <summary>Arguments of a write that must wait for an earlier write to finish.</summary>
    private class WriteData
    {
        public readonly byte[] Buffer;
        public readonly int Offset;
        public readonly int Count;
        public readonly AsyncCallback Callback;
        public readonly object State;
        // The result already handed to the caller; filled in once the write really starts.
        public readonly QueuedWriteResult AsyncResult;

        public WriteData(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
        {
            Buffer = buffer;
            Offset = offset;
            Count = count;
            Callback = callback;
            State = state;
            AsyncResult = new QueuedWriteResult(state);
        }
    }

    /// <summary>
    /// The <see cref="IAsyncResult"/> returned by every <see cref="BeginWrite"/>. It reports
    /// completion through the inner stream's result once the write has started, and carries
    /// any exception raised by the inner EndWrite.
    /// </summary>
    private class QueuedWriteResult : IAsyncResult
    {
        private readonly object _state;

        public QueuedWriteResult(object state)
        {
            _state = state;
        }

        public Exception Exception { get; set; }

        // Result of the inner stream's BeginWrite; null until the write has been started.
        public IAsyncResult ActualResult { get; set; }

        public object AsyncState => _state;

        public WaitHandle AsyncWaitHandle => throw new NotSupportedException("Queued write operations do not support wait handle.");

        public bool CompletedSynchronously => false;

        public bool IsCompleted => ActualResult != null && ActualResult.IsCompleted;
    }

    private readonly Stream _stream;
    // Writes waiting to start. Also the lock object that guards _queue and _pendingWrite.
    private readonly Queue<WriteData> _queue = new Queue<WriteData>();
    // Writes started on _stream whose completion callback has not yet run.
    private int _pendingWrite;
    private bool _disposed;

    /// <param name="stream">The stream all operations are forwarded to.</param>
    public QueuedStream(Stream stream)
    {
        _stream = stream;
    }

    public override bool CanRead => _stream.CanRead;
    public override bool CanSeek => _stream.CanSeek;
    public override bool CanWrite => _stream.CanWrite;
    public override long Length => _stream.Length;

    public override long Position
    {
        get { return _stream.Position; }
        set { _stream.Position = value; }
    }

    public override int Read(byte[] buffer, int offset, int count)
    {
        return _stream.Read(buffer, offset, count);
    }

    public override long Seek(long offset, SeekOrigin origin)
    {
        return _stream.Seek(offset, origin);
    }

    public override void SetLength(long value)
    {
        _stream.SetLength(value);
    }

    /// <exception cref="NotSupportedException">Always; only BeginWrite/EndWrite are supported.</exception>
    public override void Write(byte[] buffer, int offset, int count)
    {
        throw new NotSupportedException("QueuedStream does not support synchronous write operations yet.");
    }

    public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
    {
        return _stream.BeginRead(buffer, offset, count, callback, state);
    }

    /// <summary>
    /// Starts the write immediately when no other write is pending. Otherwise it queues the
    /// write and returns a result that completes once the write has run.
    /// </summary>
    public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
    {
        lock (_queue)
        {
            WriteData writeData = new WriteData(buffer, offset, count, callback, state);
            if (_pendingWrite > 0)
            {
                _queue.Enqueue(writeData);
                return writeData.AsyncResult;
            }
            return BeginWriteInternal(buffer, offset, count, callback, state, writeData);
        }
    }

    public override int EndRead(IAsyncResult asyncResult)
    {
        return _stream.EndRead(asyncResult);
    }

    /// <summary>Completes a write started by <see cref="BeginWrite"/>.</summary>
    /// <exception cref="ArgumentException"><paramref name="asyncResult"/> did not come from this stream's BeginWrite.</exception>
    /// <exception cref="NotSupportedException">Called before the write has started (before its callback ran).</exception>
    public override void EndWrite(IAsyncResult asyncResult)
    {
        if (!(asyncResult is QueuedWriteResult queuedResult))
        {
            throw new ArgumentException("The IAsyncResult was not returned by QueuedStream.BeginWrite.", nameof(asyncResult));
        }
        if (queuedResult.Exception != null)
        {
            // Rethrow the inner stream's failure while keeping its original stack trace.
            System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(queuedResult.Exception).Throw();
        }
        if (queuedResult.ActualResult == null)
        {
            throw new NotSupportedException("QueuedStream does not support synchronous write operations. Please wait for callback to be invoked before calling EndWrite.");
        }
    }

    public override void Flush()
    {
        _stream.Flush();
    }

    public override void Close()
    {
        _stream.Close();
    }

    protected override void Dispose(bool disposing)
    {
        if (!_disposed)
        {
            if (disposing)
            {
                _stream.Dispose();
            }
            _disposed = true;
        }
        base.Dispose(disposing);
    }

    // Starts one write on the inner stream. The completion callback starts the next queued
    // write, then notifies the caller. Must be called while holding the _queue lock.
    private IAsyncResult BeginWriteInternal(byte[] buffer, int offset, int count, AsyncCallback callback, object state, WriteData queued)
    {
        _pendingWrite++;
        IAsyncResult actualResult = _stream.BeginWrite(buffer, offset, count, ar =>
        {
            // This callback can run before BeginWrite returns, so record the inner result here too.
            queued.AsyncResult.ActualResult = ar;
            try
            {
                // Complete the inner write so the next one may start.
                _stream.EndWrite(ar);
            }
            catch (Exception ex)
            {
                // Surfaced to the caller from EndWrite.
                queued.AsyncResult.Exception = ex;
            }
            lock (_queue)
            {
                _pendingWrite--;
                // Start the next queued write. If it fails to start, report that to its own
                // callback and move on to the one after it.
                while (_queue.Count > 0)
                {
                    WriteData next = _queue.Dequeue();
                    try
                    {
                        // BeginWriteInternal stores the inner result on next.AsyncResult itself.
                        // Its return value IS next.AsyncResult, so assigning it to
                        // next.AsyncResult.ActualResult would make the result reference itself
                        // and IsCompleted would recurse without end.
                        BeginWriteInternal(next.Buffer, next.Offset, next.Count, next.Callback, next.State, next);
                    }
                    catch (Exception ex)
                    {
                        _pendingWrite--;
                        next.AsyncResult.Exception = ex;
                        next.Callback?.Invoke(next.AsyncResult);
                        continue;
                    }
                    break;
                }
                // The Stream contract allows a null callback.
                callback?.Invoke(queued.AsyncResult);
            }
        }, state);
        // Always return the wrapper, even when the inner stream completed synchronously.
        queued.AsyncResult.ActualResult = actualResult;
        return queued.AsyncResult;
    }
}
/// <summary>
/// Mutable buffer for a message being read: the payload bytes collected so far and the
/// frame type, when known.
/// </summary>
public class ReadState
{
    /// <summary>Creates an empty state with no frame type.</summary>
    public ReadState() => Data = new List<byte>();

    /// <summary>Accumulated payload bytes.</summary>
    public List<byte> Data { get; private set; }

    /// <summary>Frame type of the message, or null when not yet known.</summary>
    public FrameType? FrameType { get; set; }

    /// <summary>Discards the collected bytes and forgets the frame type.</summary>
    public void Clear()
    {
        FrameType = null;
        Data.Clear();
    }
}
/// <summary>
/// Parses raw handshake bytes into a <see cref="WebSocketHttpRequest"/>.
/// </summary>
public class RequestParser
{
    // HTTP/1.1 request line, one or more "Name: value" header lines, a blank line, then an
    // optional body (".+" without Singleline captures only up to the first newline).
    private const string pattern = "^(?<method>[^\\s]+)\\s(?<path>[^\\s]+)\\sHTTP\\/1\\.1\\r\\n((?<field_name>[^:\\r\\n]+):(?([^\\r\\n])\\s)*(?<field_value>[^\\r\\n]*)\\r\\n)+\\r\\n(?<body>.+)?";

    // A Flash client's "<policy-file-request/>" message, which is not an HTTP request.
    private const string FlashSocketPolicyRequestPattern = "^[<]policy-file-request\\s*[/][>]";

    // Built from the constants above so each pattern's text exists in exactly one place.
    private static readonly Regex _regex = new Regex(pattern, RegexOptions.IgnoreCase | RegexOptions.Compiled);
    private static readonly Regex _FlashSocketPolicyRequestRegex = new Regex(FlashSocketPolicyRequestPattern, RegexOptions.IgnoreCase | RegexOptions.Compiled);

    /// <summary>Parses <paramref name="bytes"/> assuming the "ws" scheme.</summary>
    public static WebSocketHttpRequest Parse(byte[] bytes)
    {
        return Parse(bytes, "ws");
    }

    /// <summary>
    /// Decodes <paramref name="bytes"/> as UTF-8 and parses it as an HTTP/1.1 handshake request.
    /// </summary>
    /// <param name="bytes">Raw bytes read from the client.</param>
    /// <param name="scheme">Scheme recorded on the request ("ws" or "wss").</param>
    /// <returns>
    /// The parsed request. For a Flash policy-file request, only Body and Bytes are set.
    /// Returns null when the data matches neither form (for example, when it is still incomplete).
    /// If a header name repeats, the last value wins.
    /// </returns>
    public static WebSocketHttpRequest Parse(byte[] bytes, string scheme)
    {
        string text = Encoding.UTF8.GetString(bytes);
        Match match = _regex.Match(text);
        if (!match.Success)
        {
            if (_FlashSocketPolicyRequestRegex.Match(text).Success)
            {
                return new WebSocketHttpRequest
                {
                    Body = text,
                    Bytes = bytes
                };
            }
            return null;
        }
        WebSocketHttpRequest request = new WebSocketHttpRequest
        {
            Method = match.Groups["method"].Value,
            Path = match.Groups["path"].Value,
            Body = match.Groups["body"].Value,
            Bytes = bytes,
            Scheme = scheme
        };
        // field_name and field_value are captured once per header line, so they pair up by index.
        CaptureCollection names = match.Groups["field_name"].Captures;
        CaptureCollection values = match.Groups["field_value"].Captures;
        for (int i = 0; i < names.Count; i++)
        {
            request.Headers[names[i].ToString()] = values[i].ToString();
        }
        return request;
    }
}
/// <summary>
/// <see cref="ISocket"/> implementation over a <see cref="System.Net.Sockets.Socket"/>.
/// It exposes callback-style asynchronous operations built on the APM Begin/End methods.
/// </summary>
public class SocketWrapper : ISocket
{
    /// <summary>Keep-alive idle time applied on Windows (milliseconds, per SIO_KEEPALIVE_VALS).</summary>
    public const uint KeepAliveInterval = 60000u;

    /// <summary>Keep-alive retry interval applied on Windows (milliseconds, per SIO_KEEPALIVE_VALS).</summary>
    public const uint RetryInterval = 10000u;

    private readonly Socket _socket;
    // A NetworkStream for connected sockets. Authenticate replaces it with a QueuedStream over an SslStream.
    private Stream _stream;
    // Cancelled by Close/Dispose. Send checks it, and so does Accept's end method.
    private CancellationTokenSource _tokenSource;
    private TaskFactory _taskFactory;

    /// <summary>Remote IP address as text, or null when the remote endpoint is not an IP endpoint.</summary>
    public string RemoteIpAddress => _socket.RemoteEndPoint is IPEndPoint endPoint ? endPoint.Address.ToString() : null;

    /// <summary>Remote port, or -1 when the remote endpoint is not an IP endpoint.</summary>
    public int RemotePort => _socket.RemoteEndPoint is IPEndPoint endPoint ? endPoint.Port : -1;

    public bool Connected => _socket.Connected;

    public Stream Stream => _stream;

    public bool NoDelay
    {
        get { return _socket.NoDelay; }
        set { _socket.NoDelay = value; }
    }

    public EndPoint LocalEndPoint => _socket.LocalEndPoint;

    /// <summary>
    /// Enables TCP keep-alive through IOControl(KeepAliveValues). The option buffer holds three
    /// 32-bit values: on/off (1), idle time, and retry interval.
    /// </summary>
    public void SetKeepAlive(Socket socket, uint keepAliveInterval, uint retryInterval)
    {
        int size = 4;
        byte[] array = new byte[size * 3];
        Array.Copy(BitConverter.GetBytes(1u), 0, array, 0, size);
        Array.Copy(BitConverter.GetBytes(keepAliveInterval), 0, array, size, size);
        Array.Copy(BitConverter.GetBytes(retryInterval), 0, array, size * 2, size);
        socket.IOControl(IOControlCode.KeepAliveValues, array, null);
    }

    /// <param name="socket">The socket to wrap; a stream is created only if it is already connected.</param>
    public SocketWrapper(Socket socket)
    {
        _tokenSource = new CancellationTokenSource();
        _taskFactory = new TaskFactory(_tokenSource.Token);
        _socket = socket;
        // Unconnected sockets (such as the server's listener) get no stream.
        if (_socket.Connected)
        {
            _stream = new NetworkStream(_socket);
        }
        if (FleckRuntime.IsRunningOnWindows())
        {
            SetKeepAlive(socket, KeepAliveInterval, RetryInterval);
        }
    }

    /// <summary>
    /// Wraps the current stream in an SslStream (behind a QueuedStream) and performs the server
    /// side of the TLS handshake. <paramref name="callback"/> runs when the handshake does not
    /// fault. <paramref name="error"/> runs if the handshake faults or if <paramref name="callback"/> throws.
    /// </summary>
    public Task Authenticate(X509Certificate2 certificate, SslProtocols enabledSslProtocols, Action callback, Action<Exception> error)
    {
        SslStream ssl = new SslStream(_stream, leaveInnerStreamOpen: false);
        _stream = new QueuedStream(ssl);
        Func<AsyncCallback, object, IAsyncResult> beginMethod = (AsyncCallback cb, object s) => ssl.BeginAuthenticateAsServer(certificate, clientCertificateRequired: false, enabledSslProtocols, checkCertificateRevocation: false, cb, s);
        Task task = Task.Factory.FromAsync(beginMethod, ssl.EndAuthenticateAsServer, null);
        task.ContinueWith(delegate
        {
            callback();
        }, TaskContinuationOptions.NotOnFaulted).ContinueWith(delegate(Task t)
        {
            error(t.Exception);
        }, TaskContinuationOptions.OnlyOnFaulted);
        task.ContinueWith(delegate(Task t)
        {
            error(t.Exception);
        }, TaskContinuationOptions.OnlyOnFaulted);
        return task;
    }

    public void Listen(int backlog)
    {
        _socket.Listen(backlog);
    }

    public void Bind(EndPoint endPoint)
    {
        _socket.Bind(endPoint);
    }

    /// <summary>
    /// Reads into <paramref name="buffer"/> starting at <paramref name="offset"/>.
    /// <paramref name="callback"/> receives the byte count. <paramref name="error"/> receives
    /// failures from starting or completing the read, or an exception thrown by <paramref name="callback"/>.
    /// </summary>
    /// <returns>The read task, or null when the read could not be started.</returns>
    public Task<int> Receive(byte[] buffer, Action<int> callback, Action<Exception> error, int offset)
    {
        try
        {
            // Request only the space remaining after offset. Passing buffer.Length here made any
            // non-zero offset overrun the buffer and fail with ArgumentException.
            Func<AsyncCallback, object, IAsyncResult> beginMethod = (AsyncCallback cb, object s) => _stream.BeginRead(buffer, offset, buffer.Length - offset, cb, s);
            Task<int> task = Task.Factory.FromAsync(beginMethod, (Func<IAsyncResult, int>)_stream.EndRead, (object)null);
            task.ContinueWith(delegate(Task<int> t)
            {
                callback(t.Result);
            }, TaskContinuationOptions.NotOnFaulted).ContinueWith(delegate(Task t)
            {
                error(t.Exception);
            }, TaskContinuationOptions.OnlyOnFaulted);
            task.ContinueWith(delegate(Task<int> t)
            {
                error(t.Exception);
            }, TaskContinuationOptions.OnlyOnFaulted);
            return task;
        }
        catch (Exception ex)
        {
            error(ex);
            return null;
        }
    }

    /// <summary>
    /// Accepts one connection and wraps it in a new SocketWrapper. After cancellation the end
    /// method yields null instead of calling EndAccept.
    /// </summary>
    public Task<ISocket> Accept(Action<ISocket> callback, Action<Exception> error)
    {
        Func<IAsyncResult, ISocket> endMethod = (IAsyncResult r) => (!_tokenSource.Token.IsCancellationRequested) ? new SocketWrapper(_socket.EndAccept(r)) : null;
        Task<ISocket> task = _taskFactory.FromAsync(_socket.BeginAccept, endMethod, null);
        task.ContinueWith(delegate(Task<ISocket> t)
        {
            callback(t.Result);
        }, TaskContinuationOptions.OnlyOnRanToCompletion).ContinueWith(delegate(Task t)
        {
            error(t.Exception);
        }, TaskContinuationOptions.OnlyOnFaulted);
        task.ContinueWith(delegate(Task<ISocket> t)
        {
            error(t.Exception);
        }, TaskContinuationOptions.OnlyOnFaulted);
        return task;
    }

    public void Dispose()
    {
        _tokenSource.Cancel();
        if (_stream != null)
        {
            _stream.Dispose();
        }
        if (_socket != null)
        {
            _socket.Dispose();
        }
    }

    public void Close()
    {
        _tokenSource.Cancel();
        if (_stream != null)
        {
            _stream.Close();
        }
        if (_socket != null)
        {
            _socket.Close();
        }
    }

    public int EndSend(IAsyncResult asyncResult)
    {
        _stream.EndWrite(asyncResult);
        return 0;
    }

    /// <summary>
    /// Writes all of <paramref name="buffer"/>. <paramref name="callback"/> runs when the write
    /// does not fault. <paramref name="error"/> receives failures or an exception thrown by the callback.
    /// </summary>
    /// <returns>The write task, or null when the socket was closed or the write could not be started.</returns>
    public Task Send(byte[] buffer, Action callback, Action<Exception> error)
    {
        if (_tokenSource.IsCancellationRequested)
        {
            return null;
        }
        try
        {
            Func<AsyncCallback, object, IAsyncResult> beginMethod = (AsyncCallback cb, object s) => _stream.BeginWrite(buffer, 0, buffer.Length, cb, s);
            Task task = Task.Factory.FromAsync(beginMethod, _stream.EndWrite, null);
            task.ContinueWith(delegate
            {
                callback();
            }, TaskContinuationOptions.NotOnFaulted).ContinueWith(delegate(Task t)
            {
                error(t.Exception);
            }, TaskContinuationOptions.OnlyOnFaulted);
            task.ContinueWith(delegate(Task t)
            {
                error(t.Exception);
            }, TaskContinuationOptions.OnlyOnFaulted);
            return task;
        }
        catch (Exception ex)
        {
            error(ex);
            return null;
        }
    }
}
/// <summary>
/// Signals that the client's and server's sub-protocol lists have nothing in common.
/// </summary>
public class SubProtocolNegotiationFailureException : Exception
{
    /// <summary>Creates the exception with the default message.</summary>
    public SubProtocolNegotiationFailureException() : base()
    {
    }

    /// <summary>Creates the exception with a descriptive message.</summary>
    /// <param name="message">Text describing the negotiation failure.</param>
    public SubProtocolNegotiationFailureException(string message) : base(message)
    {
    }

    /// <summary>Creates the exception wrapping the exception that caused it.</summary>
    /// <param name="message">Text describing the negotiation failure.</param>
    /// <param name="innerException">The underlying cause.</param>
    public SubProtocolNegotiationFailureException(string message, Exception innerException) : base(message, innerException)
    {
    }
}
/// <summary>
/// Picks the sub-protocol to use from the lists offered by the server and the client.
/// </summary>
public static class SubProtocolNegotiator
{
    /// <summary>
    /// Returns the first protocol in the client's list that the server also supports.
    /// </summary>
    /// <param name="server">Protocols supported by the server; null is treated as empty.</param>
    /// <param name="client">Protocols requested by the client, in the client's order; null is treated as empty.</param>
    /// <returns>The chosen protocol, or null when either list is empty (nothing to negotiate).</returns>
    /// <exception cref="SubProtocolNegotiationFailureException">Both lists are non-empty but have no protocol in common.</exception>
    public static string Negotiate(IEnumerable<string> server, IEnumerable<string> client)
    {
        if (server == null || client == null || !server.Any() || !client.Any())
        {
            return null;
        }
        // Intersect yields elements in the order of its first sequence (the client's list).
        // Returning from the first iteration enumerates the intersection only once.
        foreach (string protocol in client.Intersect(server))
        {
            return protocol;
        }
        throw new SubProtocolNegotiationFailureException("Unable to negotiate a subprotocol");
    }
}
/// <summary>
/// Server-side WebSocket connection. It reads the handshake request from the socket, builds
/// the protocol handler, then forwards received bytes to that handler and frames outgoing messages.
/// </summary>
public class WebSocketConnection : IWebSocketConnection
{
    private readonly Action<IWebSocketConnection> _initialize;
    private readonly Func<WebSocketHttpRequest, IHandler> _handlerFactory;
    private readonly Func<IEnumerable<string>, string> _negotiateSubProtocol;
    private readonly Func<byte[], WebSocketHttpRequest> _parseRequest;
    // _closing is set while a close is in progress; _closed once the socket has been closed.
    private bool _closing;
    private bool _closed;
    // Receive buffer size used by StartReceiving.
    private const int ReadSize = 4096;

    public ISocket Socket { get; set; }

    /// <summary>Protocol handler; null until a handshake request has been parsed.</summary>
    public IHandler Handler { get; set; }

    public Action OnOpen { get; set; }
    public Action OnClose { get; set; }
    public Action<string> OnMessage { get; set; }
    public Action<byte[]> OnBinary { get; set; }
    public Action<byte[]> OnPing { get; set; }
    public Action<byte[]> OnPong { get; set; }
    public Action<Exception> OnError { get; set; }

    public IWebSocketConnectionInfo ConnectionInfo { get; private set; }

    /// <summary>True when the connection is neither closing nor closed and the socket reports connected.</summary>
    public bool IsAvailable => !_closing && !_closed && Socket.Connected;

    /// <param name="socket">The accepted client socket.</param>
    /// <param name="initialize">Called with this connection once the handler exists, before the handshake response is sent.</param>
    /// <param name="parseRequest">Parses accumulated bytes; returns null while the request is incomplete.</param>
    /// <param name="handlerFactory">Builds the protocol handler for a parsed request.</param>
    /// <param name="negotiateSubProtocol">Chooses a sub-protocol from the client's list.</param>
    public WebSocketConnection(ISocket socket, Action<IWebSocketConnection> initialize, Func<byte[], WebSocketHttpRequest> parseRequest, Func<WebSocketHttpRequest, IHandler> handlerFactory, Func<IEnumerable<string>, string> negotiateSubProtocol)
    {
        Socket = socket;
        OnOpen = delegate
        {
        };
        OnClose = delegate
        {
        };
        OnMessage = delegate
        {
        };
        OnBinary = delegate
        {
        };
        // By default, answer each ping with a pong carrying the same payload.
        OnPing = delegate(byte[] x)
        {
            SendPong(x);
        };
        OnPong = delegate
        {
        };
        OnError = delegate
        {
        };
        _initialize = initialize;
        _handlerFactory = handlerFactory;
        _parseRequest = parseRequest;
        _negotiateSubProtocol = negotiateSubProtocol;
    }

    // The framing calls below go through lambdas so that Handler is read only after Send<T>
    // has checked it. A method group such as Handler.FrameText is evaluated at the call site,
    // so a null Handler would throw NullReferenceException instead of the intended
    // InvalidOperationException.

    public Task Send(string message)
    {
        return Send(message, text => Handler.FrameText(text));
    }

    public Task Send(byte[] message)
    {
        return Send(message, bytes => Handler.FrameBinary(bytes));
    }

    public Task SendPing(byte[] message)
    {
        return Send(message, bytes => Handler.FramePing(bytes));
    }

    public Task SendPong(byte[] message)
    {
        return Send(message, bytes => Handler.FramePong(bytes));
    }

    // Frames and sends a message. The returned task is faulted with
    // ConnectionNotAvailableException when the connection is closing or closed.
    private Task Send<T>(T message, Func<T, byte[]> createFrame)
    {
        if (Handler == null)
        {
            throw new InvalidOperationException("Cannot send before handshake");
        }
        if (!IsAvailable)
        {
            FleckLog.Warn("Data sent while closing or after close. Ignoring.");
            TaskCompletionSource<object> taskCompletionSource = new TaskCompletionSource<object>();
            taskCompletionSource.SetException(new ConnectionNotAvailableException("Data sent while closing or after close. Ignoring."));
            return taskCompletionSource.Task;
        }
        byte[] bytes = createFrame(message);
        return SendBytes(bytes);
    }

    /// <summary>Starts the receive loop. It buffers handshake bytes until a handler exists.</summary>
    public void StartReceiving()
    {
        List<byte> data = new List<byte>(ReadSize);
        byte[] buffer = new byte[ReadSize];
        Read(data, buffer);
    }

    /// <summary>Closes the connection with status 1000.</summary>
    public void Close()
    {
        Close(1000);
    }

    /// <summary>
    /// Sends a close frame with <paramref name="code"/>, then closes the socket. If there is
    /// no handler, or the handler produces an empty frame, the socket is closed immediately.
    /// Does nothing when the connection is not available.
    /// </summary>
    public void Close(int code)
    {
        if (!IsAvailable)
        {
            return;
        }
        _closing = true;
        if (Handler == null)
        {
            CloseSocket();
            return;
        }
        byte[] array = Handler.FrameClose(code);
        if (array.Length == 0)
        {
            CloseSocket();
        }
        else
        {
            SendBytes(array, CloseSocket);
        }
    }

    /// <summary>
    /// Tries to parse the bytes received so far as a handshake request. On success it builds
    /// the handler, negotiates the sub-protocol, runs the initialize callback and sends the
    /// handshake response (OnOpen fires after it is sent).
    /// </summary>
    public void CreateHandler(IEnumerable<byte> data)
    {
        WebSocketHttpRequest webSocketHttpRequest = _parseRequest(data.ToArray());
        if (webSocketHttpRequest != null)
        {
            Handler = _handlerFactory(webSocketHttpRequest);
            if (Handler != null)
            {
                string text = _negotiateSubProtocol(webSocketHttpRequest.SubProtocols);
                ConnectionInfo = WebSocketConnectionInfo.Create(webSocketHttpRequest, Socket.RemoteIpAddress, Socket.RemotePort, text);
                _initialize(this);
                byte[] bytes = Handler.CreateHandshake(text);
                SendBytes(bytes, OnOpen);
            }
        }
    }

    // One receive iteration. Re-issues itself after every successful read and closes the
    // socket when a read returns 0 bytes.
    private void Read(List<byte> data, byte[] buffer)
    {
        if (!IsAvailable)
        {
            return;
        }
        Socket.Receive(buffer, delegate(int r)
        {
            if (r <= 0)
            {
                FleckLog.Debug("0 bytes read. Closing.");
                CloseSocket();
            }
            else
            {
                FleckLog.Debug(r + " bytes read");
                IEnumerable<byte> enumerable = buffer.Take(r);
                if (Handler != null)
                {
                    Handler.Receive(enumerable);
                }
                else
                {
                    data.AddRange(enumerable);
                    CreateHandler(data);
                }
                Read(data, buffer);
            }
        }, HandleReadError);
    }

    // Maps a read failure to a close code. It unwraps AggregateException, ignores
    // ObjectDisposedException, and reports everything else to OnError first.
    private void HandleReadError(Exception e)
    {
        if (e is AggregateException aggregate)
        {
            HandleReadError(aggregate.InnerException);
            return;
        }
        if (e is ObjectDisposedException)
        {
            FleckLog.Debug("Swallowing ObjectDisposedException", e);
            return;
        }
        OnError(e);
        if (e is WebSocketException webSocketException)
        {
            FleckLog.Debug("Error while reading", e);
            Close(webSocketException.StatusCode);
        }
        else if (e is SubProtocolNegotiationFailureException)
        {
            FleckLog.Debug(e.Message);
            Close(1002);
        }
        else if (e is IOException)
        {
            FleckLog.Debug("Error while reading", e);
            Close(1006);
        }
        else
        {
            FleckLog.Error("Application Error", e);
            Close(1011);
        }
    }

    // Sends raw bytes. On failure it logs and closes the socket.
    private Task SendBytes(byte[] bytes, Action callback = null)
    {
        return Socket.Send(bytes, delegate
        {
            FleckLog.Debug("Sent " + bytes.Length + " bytes");
            if (callback != null)
            {
                callback();
            }
        }, delegate(Exception e)
        {
            if (e is IOException)
            {
                FleckLog.Debug("Failed to send. Disconnecting.", e);
            }
            else
            {
                FleckLog.Info("Failed to send. Disconnecting.", e);
            }
            CloseSocket();
        });
    }

    // Fires OnClose, marks the connection closed, then closes and disposes the socket.
    private void CloseSocket()
    {
        _closing = true;
        OnClose();
        _closed = true;
        Socket.Close();
        Socket.Dispose();
        _closing = false;
    }
}
/// <summary>
/// Snapshot of a connection's handshake request and remote endpoint.
/// </summary>
public class WebSocketConnectionInfo : IWebSocketConnectionInfo
{
    // "name=value" pairs separated by ';' and optional whitespace.
    private const string CookiePattern = "((;)*(\\s)*(?<cookie_name>[^=]+)=(?<cookie_value>[^\\;]+))+";

    // Built from CookiePattern so the pattern text exists in exactly one place.
    private static readonly Regex CookieRegex = new Regex(CookiePattern, RegexOptions.Compiled);

    public string NegotiatedSubProtocol { get; private set; }
    public string SubProtocol { get; private set; }
    public string Origin { get; private set; }
    public string Host { get; private set; }
    public string Path { get; private set; }
    public string ClientIpAddress { get; set; }
    public int ClientPort { get; set; }
    public Guid Id { get; set; }
    public IDictionary<string, string> Cookies { get; private set; }
    public IDictionary<string, string> Headers { get; private set; }

    /// <summary>
    /// Builds the info from a parsed request. Origin falls back to Sec-WebSocket-Origin.
    /// Headers are copied into a case-insensitive dictionary, and the Cookie header, if
    /// present, is split into <see cref="Cookies"/>.
    /// </summary>
    /// <param name="request">The parsed handshake request.</param>
    /// <param name="clientIp">Remote IP address as text.</param>
    /// <param name="clientPort">Remote port.</param>
    /// <param name="negotiatedSubprotocol">The sub-protocol chosen for this connection, or null.</param>
    public static WebSocketConnectionInfo Create(WebSocketHttpRequest request, string clientIp, int clientPort, string negotiatedSubprotocol)
    {
        WebSocketConnectionInfo info = new WebSocketConnectionInfo
        {
            Origin = (request["Origin"] ?? request["Sec-WebSocket-Origin"]),
            Host = request["Host"],
            SubProtocol = request["Sec-WebSocket-Protocol"],
            Path = request.Path,
            ClientIpAddress = clientIp,
            ClientPort = clientPort,
            NegotiatedSubProtocol = negotiatedSubprotocol,
            Headers = new Dictionary<string, string>(request.Headers, StringComparer.InvariantCultureIgnoreCase)
        };
        string cookieHeader = request["Cookie"];
        if (cookieHeader != null)
        {
            Match match = CookieRegex.Match(cookieHeader);
            // Names and values are captured once per pair, so they line up by index.
            CaptureCollection names = match.Groups["cookie_name"].Captures;
            CaptureCollection values = match.Groups["cookie_value"].Captures;
            for (int i = 0; i < names.Count; i++)
            {
                info.Cookies[names[i].ToString()] = values[i].ToString();
            }
        }
        return info;
    }

    // Instances come only from Create. Each gets an empty cookie map and a fresh Id.
    private WebSocketConnectionInfo()
    {
        Cookies = new Dictionary<string, string>();
        Id = Guid.NewGuid();
    }
}
/// <summary>
/// Protocol error carrying the WebSocket status code to close the connection with.
/// </summary>
public class WebSocketException : Exception
{
    /// <summary>Creates the exception for <paramref name="statusCode"/> with the default message.</summary>
    public WebSocketException(ushort statusCode) : base()
    {
        StatusCode = statusCode;
    }

    /// <summary>Creates the exception for <paramref name="statusCode"/> with a descriptive message.</summary>
    public WebSocketException(ushort statusCode, string message) : base(message)
    {
        StatusCode = statusCode;
    }

    /// <summary>Creates the exception for <paramref name="statusCode"/>, wrapping its cause.</summary>
    public WebSocketException(ushort statusCode, string message, Exception innerException) : base(message, innerException)
    {
        StatusCode = statusCode;
    }

    /// <summary>The WebSocket close status code associated with this error.</summary>
    public ushort StatusCode { get; private set; }
}
/// <summary>
/// A parsed handshake request: request line fields, headers, body and raw bytes.
/// </summary>
public class WebSocketHttpRequest
{
    // Header names are compared case-insensitively.
    private readonly IDictionary<string, string> _headers = new Dictionary<string, string>(StringComparer.InvariantCultureIgnoreCase);

    public string Method { get; set; }
    public string Path { get; set; }
    public string Body { get; set; }
    public string Scheme { get; set; }
    public byte[] Bytes { get; set; }

    /// <summary>Value of header <paramref name="name"/>, or null when absent.</summary>
    public string this[string name] => _headers.TryGetValue(name, out string found) ? found : null;

    /// <summary>All request headers (case-insensitive keys).</summary>
    public IDictionary<string, string> Headers => _headers;

    /// <summary>
    /// Sub-protocols listed in Sec-WebSocket-Protocol, split on commas and spaces.
    /// Returns an empty array when the header is missing.
    /// </summary>
    public string[] SubProtocols
    {
        get
        {
            string raw;
            return _headers.TryGetValue("Sec-WebSocket-Protocol", out raw)
                ? raw.Split(new[] { ',', ' ' }, StringSplitOptions.RemoveEmptyEntries)
                : new string[0];
        }
    }
}
/// <summary>
/// WebSocket server listening on the IP address and port of a ws:// or wss:// location.
/// </summary>
public class WebSocketServer : IWebSocketServer, IDisposable
{
    private readonly string _scheme;
    private readonly IPAddress _locationIP;
    // Per-connection configuration callback supplied to Start.
    private Action<IWebSocketConnection> _config;

    public ISocket ListenerSocket { get; set; }
    public string Location { get; private set; }
    public bool SupportDualStack { get; }
    public int Port { get; private set; }
    public X509Certificate2 Certificate { get; set; }
    public SslProtocols EnabledSslProtocols { get; set; }
    public IEnumerable<string> SupportedSubProtocols { get; set; }
    public bool RestartAfterListenError { get; set; }

    /// <summary>True when the scheme is "wss" and a certificate has been assigned.</summary>
    public bool IsSecure => _scheme == "wss" && Certificate != null;

    /// <param name="location">Server location, e.g. "ws://0.0.0.0:8181"; the host must be an IP address.</param>
    /// <param name="supportDualStack">Request IPv4+IPv6 listening where supported.</param>
    /// <exception cref="FormatException">The host part of <paramref name="location"/> is not a valid IP address.</exception>
    public WebSocketServer(string location, bool supportDualStack = true)
    {
        Uri uri = new Uri(location);
        Port = uri.Port;
        Location = location;
        SupportDualStack = supportDualStack;
        _locationIP = ParseIPAddress(uri);
        _scheme = uri.Scheme;
        ListenerSocket = new SocketWrapper(CreateListenerSocket());
        SupportedSubProtocols = new string[0];
    }

    public void Dispose()
    {
        ListenerSocket.Dispose();
    }

    // Creates the TCP listener socket for _locationIP with address reuse enabled and, when
    // requested and supported, dual-stack mode. Shared by the constructor and the restart path
    // so that a restarted listener gets the same options as the original.
    private Socket CreateListenerSocket()
    {
        Socket socket = new Socket(_locationIP.AddressFamily, SocketType.Stream, ProtocolType.IP);
        socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, 1);
        if (SupportDualStack && !FleckRuntime.IsRunningOnMono() && FleckRuntime.IsRunningOnWindows())
        {
            socket.SetSocketOption(SocketOptionLevel.IPv6, SocketOptionName.IPv6Only, optionValue: false);
        }
        return socket;
    }

    // Maps the location host to an IPAddress. The literal all-zero forms map to Any/IPv6Any.
    private IPAddress ParseIPAddress(Uri uri)
    {
        string host = uri.Host;
        if (host == "0.0.0.0")
        {
            return IPAddress.Any;
        }
        if (host == "[0000:0000:0000:0000:0000:0000:0000:0000]")
        {
            return IPAddress.IPv6Any;
        }
        try
        {
            return IPAddress.Parse(host);
        }
        catch (Exception innerException)
        {
            throw new FormatException("Failed to parse the IP address part of the location. Please make sure you specify a valid IP address. Use 0.0.0.0 or [::] to listen on all interfaces.", innerException);
        }
    }

    /// <summary>
    /// Binds, listens (backlog 100) and starts accepting clients. <see cref="Port"/> is updated
    /// to the actually bound port. For "wss" without a certificate it logs an error and returns
    /// without accepting clients.
    /// </summary>
    /// <param name="config">Called for each connection so the caller can attach event handlers.</param>
    public void Start(Action<IWebSocketConnection> config)
    {
        IPEndPoint ipLocal = new IPEndPoint(_locationIP, Port);
        ListenerSocket.Bind(ipLocal);
        ListenerSocket.Listen(100);
        Port = ((IPEndPoint)ListenerSocket.LocalEndPoint).Port;
        FleckLog.Info($"Server started at {Location} (actual port {Port})");
        if (_scheme == "wss")
        {
            if (Certificate == null)
            {
                FleckLog.Error("Scheme cannot be 'wss' without a Certificate");
                return;
            }
            if (EnabledSslProtocols == SslProtocols.None)
            {
                // NOTE(review): TLS 1.0 is deprecated; kept as the default for compatibility.
                EnabledSslProtocols = SslProtocols.Tls;
                FleckLog.Debug("Using default TLS 1.0 security protocol.");
            }
        }
        // Store the callback before accepting anything. Otherwise a client that connects
        // immediately could be wired up with a null configuration callback.
        _config = config;
        ListenForClients();
    }

    // Issues one asynchronous accept. OnClientConnect re-arms it for the next client.
    // On listener failure it optionally rebuilds the listener socket and restarts.
    private void ListenForClients()
    {
        ListenerSocket.Accept(OnClientConnect, delegate(Exception e)
        {
            FleckLog.Error("Listener socket is closed", e);
            if (RestartAfterListenError)
            {
                FleckLog.Info("Listener socket restarting");
                try
                {
                    ListenerSocket.Dispose();
                    ListenerSocket = new SocketWrapper(CreateListenerSocket());
                    Start(_config);
                    FleckLog.Info("Listener socket restarted");
                }
                catch (Exception ex)
                {
                    FleckLog.Error("Listener could not be restarted", ex);
                }
            }
        });
    }

    // Handles one accepted client: re-arms the accept loop, builds the connection, and starts
    // receiving, after a TLS handshake when the server is secure.
    private void OnClientConnect(ISocket clientSocket)
    {
        if (clientSocket == null)
        {
            return;
        }
        FleckLog.Debug($"Client connected from {clientSocket.RemoteIpAddress}:{clientSocket.RemotePort.ToString()}");
        ListenForClients();
        WebSocketConnection connection = null;
        connection = new WebSocketConnection(clientSocket, _config, (byte[] bytes) => RequestParser.Parse(bytes, _scheme), (WebSocketHttpRequest r) => HandlerFactory.BuildHandler(r, delegate(string s)
        {
            connection.OnMessage(s);
        }, connection.Close, delegate(byte[] b)
        {
            connection.OnBinary(b);
        }, delegate(byte[] b)
        {
            connection.OnPing(b);
        }, delegate(byte[] b)
        {
            connection.OnPong(b);
        }), (IEnumerable<string> s) => SubProtocolNegotiator.Negotiate(SupportedSubProtocols, s));
        if (IsSecure)
        {
            FleckLog.Debug("Authenticating Secure Connection");
            clientSocket.Authenticate(Certificate, EnabledSslProtocols, connection.StartReceiving, delegate(Exception e)
            {
                FleckLog.Warn("Failed to Authenticate", e);
                // The connection can never be used after a failed handshake, so release the
                // socket instead of leaking it.
                clientSocket.Dispose();
            });
        }
        else
        {
            connection.StartReceiving();
        }
    }
}
/// <summary>
/// Named WebSocket close status codes.
/// </summary>
public static class WebSocketStatusCodes
{
    public const ushort NormalClosure = 1000;
    public const ushort GoingAway = 1001;
    public const ushort ProtocolError = 1002;
    public const ushort UnsupportedDataType = 1003;
    public const ushort NoStatusReceived = 1005;
    public const ushort AbnormalClosure = 1006;
    public const ushort InvalidFramePayloadData = 1007;
    public const ushort PolicyViolation = 1008;
    public const ushort MessageTooBig = 1009;
    public const ushort MandatoryExt = 1010;
    public const ushort InternalServerError = 1011;
    public const ushort TLSHandshake = 1015;
    public const ushort ApplicationError = 3000;

    /// <summary>
    /// Codes listed as valid close codes. NoStatusReceived, AbnormalClosure and TLSHandshake
    /// are deliberately absent. Usage is not visible here; presumably used to validate
    /// received close frames.
    /// </summary>
    public static ushort[] ValidCloseCodes = new ushort[]
    {
        NormalClosure,
        GoingAway,
        ProtocolError,
        UnsupportedDataType,
        InvalidFramePayloadData,
        PolicyViolation,
        MessageTooBig,
        MandatoryExt,
        InternalServerError
    };
}
}
namespace Fleck.Helpers
{
/// <summary>
/// Runtime and operating-system checks used to decide which socket options to apply.
/// </summary>
internal static class FleckRuntime
{
    /// <summary>True when executing on Mono (detected by the presence of the Mono.Runtime type).</summary>
    public static bool IsRunningOnMono()
    {
        return Type.GetType("Mono.Runtime") != null;
    }

    /// <summary>
    /// True when the current operating system is Windows. Callers gate Windows-specific socket
    /// setup (IOControl keep-alive values, dual-stack listening) on this check, so it must not
    /// report true on other platforms. All supported Windows versions report
    /// <see cref="PlatformID.Win32NT"/>.
    /// </summary>
    public static bool IsRunningOnWindows()
    {
        return Environment.OSVersion.Platform == PlatformID.Win32NT;
    }
}
}
namespace Fleck.Handlers
{
/// <summary>
/// <see cref="IHandler"/> whose behaviour comes from assignable delegate fields. Every
/// delegate defaults to a no-op that returns an empty array, so a protocol only sets the
/// parts it supports.
/// </summary>
public class ComposableHandler : IHandler
{
    public Func<string, byte[]> Handshake = _ => new byte[0];
    public Func<string, byte[]> TextFrame = _ => new byte[0];
    public Func<byte[], byte[]> BinaryFrame = _ => new byte[0];

    // Called with the whole accumulated buffer, not just the newly received bytes, so the
    // delegate may consume data and remove it from the list.
    public Action<List<byte>> ReceiveData = _ => { };

    public Func<byte[], byte[]> PingFrame = _ => new byte[0];
    public Func<byte[], byte[]> PongFrame = _ => new byte[0];
    public Func<int, byte[]> CloseFrame = _ => new byte[0];

    // Bytes received so far that ReceiveData has not removed.
    private readonly List<byte> _data = new List<byte>();

    public byte[] CreateHandshake(string subProtocol = null) => Handshake(subProtocol);

    /// <summary>Appends <paramref name="data"/> to the buffer and hands the buffer to <see cref="ReceiveData"/>.</summary>
    public void Receive(IEnumerable<byte> data)
    {
        _data.AddRange(data);
        ReceiveData(_data);
    }

    public byte[] FrameText(string text) => TextFrame(text);

    public byte[] FrameBinary(byte[] bytes) => BinaryFrame(bytes);

    public byte[] FramePing(byte[] bytes) => PingFrame(bytes);

    public byte[] FramePong(byte[] bytes) => PongFrame(bytes);

    public byte[] FrameClose(int code) => CloseFrame(code);
}
/// <summary>
/// Handshake and framing for the legacy Draft-76 WebSocket protocol, where a
/// text frame is a 0x00 byte, the UTF-8 payload, and a terminating 0xFF byte.
/// </summary>
public static class Draft76Handler
{
    private const byte End = byte.MaxValue;
    private const byte Start = 0;
    // Largest frame (measured as the index of the 0xFF terminator) accepted from a client.
    private const int MaxSize = 5242880;

    /// <summary>Creates a handler that performs the Draft-76 handshake and text framing.</summary>
    public static IHandler Create(WebSocketHttpRequest request, Action<string> onMessage)
    {
        return new ComposableHandler
        {
            TextFrame = FrameText,
            Handshake = (string sub) => Handshake(request, sub),
            ReceiveData = delegate(List<byte> data)
            {
                ReceiveData(onMessage, data);
            }
        };
    }

    /// <summary>
    /// Decodes every complete frame at the front of <paramref name="data"/>,
    /// removes it from the buffer and passes the decoded text to <paramref name="onMessage"/>.
    /// An incomplete trailing frame is left in the buffer.
    /// </summary>
    /// <exception cref="WebSocketException">
    /// 1007 when a frame does not start with 0x00; 1009 when a frame exceeds <see cref="MaxSize"/>.
    /// </exception>
    public static void ReceiveData(Action<string> onMessage, List<byte> data)
    {
        while (data.Count > 0)
        {
            if (data[0] != Start)
            {
                throw new WebSocketException(1007); // InvalidFramePayloadData
            }
            int endIndex = data.IndexOf(End);
            if (endIndex < 0)
            {
                // No terminator yet. Any terminator that arrives later will sit at an
                // index >= data.Count, so once the buffer already exceeds MaxSize the
                // frame is doomed; reject it now instead of buffering without bound.
                if (data.Count > MaxSize)
                {
                    throw new WebSocketException(1009); // MessageTooBig
                }
                break;
            }
            if (endIndex > MaxSize)
            {
                throw new WebSocketException(1009); // MessageTooBig
            }
            byte[] payload = data.Skip(1).Take(endIndex - 1).ToArray();
            data.RemoveRange(0, endIndex + 1);
            onMessage(Encoding.UTF8.GetString(payload));
        }
    }

    /// <summary>Wraps UTF-8 encoded <paramref name="data"/> in 0x00 ... 0xFF.</summary>
    public static byte[] FrameText(string data)
    {
        byte[] payload = Encoding.UTF8.GetBytes(data);
        byte[] frame = new byte[payload.Length + 2];
        frame[0] = Start;
        // Explicit index instead of ^1: System.Index is unavailable on .NET Framework 4.5.
        frame[frame.Length - 1] = End;
        Array.Copy(payload, 0, frame, 1, payload.Length);
        return frame;
    }

    /// <summary>
    /// Builds the Draft-76 handshake response: the HTTP headers followed by the
    /// 16-byte challenge answer computed from the two keys and the last 8 bytes of the request.
    /// </summary>
    public static byte[] Handshake(WebSocketHttpRequest request, string subProtocol)
    {
        FleckLog.Debug("Building Draft76 Response");
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.Append("HTTP/1.1 101 WebSocket Protocol Handshake\r\n");
        stringBuilder.Append("Upgrade: WebSocket\r\n");
        stringBuilder.Append("Connection: Upgrade\r\n");
        stringBuilder.AppendFormat("Sec-WebSocket-Origin: {0}\r\n", request["Origin"]);
        stringBuilder.AppendFormat("Sec-WebSocket-Location: {0}://{1}{2}\r\n", request.Scheme, request["Host"], request.Path);
        if (subProtocol != null)
        {
            stringBuilder.AppendFormat("Sec-WebSocket-Protocol: {0}\r\n", subProtocol);
        }
        stringBuilder.Append("\r\n");
        string key1 = request["Sec-WebSocket-Key1"];
        string key2 = request["Sec-WebSocket-Key2"];
        ArraySegment<byte> challenge = new ArraySegment<byte>(request.Bytes, request.Bytes.Length - 8, 8);
        byte[] answer = CalculateAnswerBytes(key1, key2, challenge);
        byte[] response = Encoding.ASCII.GetBytes(stringBuilder.ToString());
        int headerLength = response.Length;
        Array.Resize(ref response, headerLength + answer.Length);
        Array.Copy(answer, 0, response, headerLength, answer.Length);
        return response;
    }

    /// <summary>
    /// Returns MD5(key1-number || key2-number || 8 challenge bytes), with each key
    /// number written as 4 big-endian bytes.
    /// </summary>
    public static byte[] CalculateAnswerBytes(string key1, string key2, ArraySegment<byte> challenge)
    {
        byte[] key1Bytes = ParseKey(key1);
        byte[] key2Bytes = ParseKey(key2);
        byte[] input = new byte[16];
        Array.Copy(key1Bytes, 0, input, 0, 4);
        Array.Copy(key2Bytes, 0, input, 4, 4);
        Array.Copy(challenge.Array, challenge.Offset, input, 8, 8);
        using (MD5 md5 = MD5.Create())
        {
            return md5.ComputeHash(input);
        }
    }

    /// <summary>
    /// Converts a Draft-76 key to 4 big-endian bytes: the key's ASCII digits,
    /// read as one number, divided by the number of spaces in the key.
    /// </summary>
    /// <remarks>
    /// NOTE(review): a key with no spaces throws DivideByZeroException, and a key
    /// with no digits throws FormatException; presumably the caller treats any
    /// exception as a failed handshake — confirm.
    /// </remarks>
    private static byte[] ParseKey(string key)
    {
        int spaces = key.Count(c => c == ' ');
        // Only ASCII 0-9 count as digits in the key; char.IsDigit would also admit other Unicode digits.
        string digits = new string(key.Where(c => c >= '0' && c <= '9').ToArray());
        byte[] bytes = BitConverter.GetBytes((int)(long.Parse(digits) / spaces));
        if (BitConverter.IsLittleEndian)
        {
            Array.Reverse(bytes);
        }
        return bytes;
    }
}
/// <summary>
/// Answers Flash socket policy requests with a fixed cross-domain policy
/// document that allows access from any domain to any port.
/// </summary>
public class FlashSocketPolicyRequestHandler
{
    /// <summary>The policy XML sent back, NUL-terminated.</summary>
    public static string PolicyResponse = "<?xml version=\"1.0\"?>\n<cross-domain-policy>\n <allow-access-from domain=\"*\" to-ports=\"*\"/>\n <site-control permitted-cross-domain-policies=\"all\"/>\n</cross-domain-policy>\n\0";

    /// <summary>Creates a handler whose only behavior is the policy "handshake".</summary>
    public static IHandler Create(WebSocketHttpRequest request) =>
        new ComposableHandler
        {
            Handshake = subProtocol => Handshake(request, subProtocol)
        };

    /// <summary>Returns <see cref="PolicyResponse"/> as UTF-8; both arguments are ignored.</summary>
    public static byte[] Handshake(WebSocketHttpRequest request, string subProtocol)
    {
        FleckLog.Debug("Building Flash Socket Policy Response");
        byte[] policyBytes = Encoding.UTF8.GetBytes(PolicyResponse);
        return policyBytes;
    }
}
/// <summary>
/// Handshake, outgoing framing and incoming frame parsing for the Hybi-13
/// (RFC 6455) WebSocket protocol.
/// </summary>
public static class Hybi13Handler
{
    // Appended to the client's Sec-WebSocket-Key before hashing (RFC 6455 section 4.2.2).
    private const string WebSocketResponseGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /// <summary>
    /// Creates a handler wired to the given callbacks. The handler keeps its own
    /// <see cref="ReadState"/>, so fragmented messages are reassembled per handler instance.
    /// </summary>
    public static IHandler Create(WebSocketHttpRequest request, Action<string> onMessage, Action onClose, Action<byte[]> onBinary, Action<byte[]> onPing, Action<byte[]> onPong)
    {
        ReadState readState = new ReadState();
        return new ComposableHandler
        {
            Handshake = (string sub) => BuildHandshake(request, sub),
            TextFrame = (string s) => FrameData(Encoding.UTF8.GetBytes(s), FrameType.Text),
            BinaryFrame = (byte[] s) => FrameData(s, FrameType.Binary),
            PingFrame = (byte[] s) => FrameData(s, FrameType.Ping),
            PongFrame = (byte[] s) => FrameData(s, FrameType.Pong),
            CloseFrame = (int i) => FrameData(i.ToBigEndianBytes<ushort>(), FrameType.Close),
            ReceiveData = delegate(List<byte> d)
            {
                ReceiveData(d, readState, delegate(FrameType op, byte[] data)
                {
                    ProcessFrame(op, data, onMessage, onClose, onBinary, onPing, onPong);
                });
            }
        };
    }

    /// <summary>
    /// Builds one unfragmented (FIN set), unmasked frame containing <paramref name="payload"/>.
    /// The length uses the 7-bit, 16-bit (126) or 64-bit (127) encoding as needed.
    /// </summary>
    public static byte[] FrameData(byte[] payload, FrameType frameType)
    {
        using (MemoryStream memoryStream = new MemoryStream())
        {
            // 128 sets the FIN bit; the low bits carry the opcode.
            memoryStream.WriteByte((byte)(frameType + 128));
            if (payload.Length > 65535)
            {
                memoryStream.WriteByte(127);
                byte[] extendedLength = payload.Length.ToBigEndianBytes<ulong>();
                memoryStream.Write(extendedLength, 0, extendedLength.Length);
            }
            else if (payload.Length > 125)
            {
                memoryStream.WriteByte(126);
                byte[] extendedLength = payload.Length.ToBigEndianBytes<ushort>();
                memoryStream.Write(extendedLength, 0, extendedLength.Length);
            }
            else
            {
                memoryStream.WriteByte((byte)payload.Length);
            }
            memoryStream.Write(payload, 0, payload.Length);
            return memoryStream.ToArray();
        }
    }

    /// <summary>
    /// Parses every complete frame at the front of <paramref name="data"/> and removes it.
    /// Data frames are accumulated in <paramref name="readState"/> until FIN. Control frames
    /// (which may be interleaved with fragments) are dispatched immediately without touching
    /// the partial message. An incomplete trailing frame is left in the buffer.
    /// </summary>
    /// <exception cref="WebSocketException">
    /// ProtocolError for framing violations; MessageTooBig when the declared payload
    /// length cannot be represented in memory.
    /// </exception>
    public static void ReceiveData(List<byte> data, ReadState readState, Action<FrameType, byte[]> processFrame)
    {
        while (data.Count >= 2)
        {
            byte first = data[0];
            byte second = data[1];
            bool isFinal = (first & 0x80) != 0;
            int reservedBits = first & 0x70;
            FrameType frameType = (FrameType)(first & 0x0F);
            // Opcodes 0x8-0xF are control frames (RFC 6455 section 5.5).
            bool isControlFrame = (first & 0x08) != 0;
            bool isMasked = (second & 0x80) != 0;
            int lengthIndicator = second & 0x7F;

            // Client frames must be masked, use a known opcode and carry no RSV bits.
            if (!isMasked || !Enum.IsDefined(typeof(FrameType), frameType) || reservedBits != 0)
            {
                throw new WebSocketException(WebSocketStatusCodes.ProtocolError);
            }
            if (isControlFrame)
            {
                // Control frames must not be fragmented and carry at most 125 bytes.
                if (!isFinal || lengthIndicator > 125)
                {
                    throw new WebSocketException(WebSocketStatusCodes.ProtocolError);
                }
            }
            else if (frameType == FrameType.Continuation)
            {
                // A continuation needs an open fragmented message to continue.
                if (!readState.FrameType.HasValue)
                {
                    throw new WebSocketException(WebSocketStatusCodes.ProtocolError);
                }
            }
            else if (readState.FrameType.HasValue)
            {
                // A new data message may not start while another is still fragmented.
                throw new WebSocketException(WebSocketStatusCodes.ProtocolError);
            }

            int headerLength = 2;
            ulong declaredLength;
            switch (lengthIndicator)
            {
                case 127:
                    if (data.Count < headerLength + 8)
                    {
                        return;
                    }
                    declaredLength = ReadBigEndian(data, headerLength, 8);
                    headerLength += 8;
                    break;
                case 126:
                    if (data.Count < headerLength + 2)
                    {
                        return;
                    }
                    declaredLength = ReadBigEndian(data, headerLength, 2);
                    headerLength += 2;
                    break;
                default:
                    declaredLength = (ulong)lengthIndicator;
                    break;
            }

            int maskOffset = headerLength;
            int payloadOffset = maskOffset + 4;
            // Reject lengths that would overflow int arithmetic below (and could never be buffered anyway).
            if (declaredLength > (ulong)(int.MaxValue - payloadOffset))
            {
                throw new WebSocketException(WebSocketStatusCodes.MessageTooBig);
            }
            int payloadLength = (int)declaredLength;
            if (data.Count < payloadOffset + payloadLength)
            {
                return;
            }

            byte[] payload = new byte[payloadLength];
            for (int i = 0; i < payloadLength; i++)
            {
                payload[i] = (byte)(data[payloadOffset + i] ^ data[maskOffset + (i % 4)]);
            }
            data.RemoveRange(0, payloadOffset + payloadLength);

            if (isControlFrame)
            {
                processFrame(frameType, payload);
                continue;
            }

            readState.Data.AddRange(payload);
            if (frameType != FrameType.Continuation)
            {
                readState.FrameType = frameType;
            }
            if (isFinal)
            {
                byte[] message = readState.Data.ToArray();
                FrameType messageType = readState.FrameType.Value;
                readState.Clear();
                processFrame(messageType, message);
            }
        }
    }

    // Reads `count` bytes starting at `offset` as an unsigned big-endian integer.
    private static ulong ReadBigEndian(List<byte> data, int offset, int count)
    {
        ulong value = 0;
        for (int i = 0; i < count; i++)
        {
            value = (value << 8) | data[offset + i];
        }
        return value;
    }

    /// <summary>
    /// Dispatches a complete message or control frame to the matching callback.
    /// Close payloads are validated: the status code must be allowed, and the reason must be valid UTF-8.
    /// </summary>
    public static void ProcessFrame(FrameType frameType, byte[] data, Action<string> onMessage, Action onClose, Action<byte[]> onBinary, Action<byte[]> onPing, Action<byte[]> onPong)
    {
        switch (frameType)
        {
            case FrameType.Close:
                if (data.Length == 1 || data.Length > 125)
                {
                    throw new WebSocketException(WebSocketStatusCodes.ProtocolError);
                }
                if (data.Length >= 2)
                {
                    // Close status code is a big-endian 16-bit integer.
                    ushort closeCode = (ushort)((data[0] << 8) | data[1]);
                    if (!WebSocketStatusCodes.ValidCloseCodes.Contains(closeCode) && (closeCode < 3000 || closeCode > 4999))
                    {
                        throw new WebSocketException(WebSocketStatusCodes.ProtocolError);
                    }
                }
                if (data.Length > 2)
                {
                    // Result discarded: called only to validate the reason text.
                    ReadUTF8PayloadData(data.Skip(2).ToArray());
                }
                onClose();
                break;
            case FrameType.Binary:
                onBinary(data);
                break;
            case FrameType.Ping:
                onPing(data);
                break;
            case FrameType.Pong:
                onPong(data);
                break;
            case FrameType.Text:
                onMessage(ReadUTF8PayloadData(data));
                break;
            default:
                FleckLog.Debug("Received unhandled " + frameType);
                break;
        }
    }

    /// <summary>Builds the 101 Switching Protocols response including Sec-WebSocket-Accept.</summary>
    public static byte[] BuildHandshake(WebSocketHttpRequest request, string subProtocol)
    {
        FleckLog.Debug("Building Hybi-14 Response");
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.Append("HTTP/1.1 101 Switching Protocols\r\n");
        stringBuilder.Append("Upgrade: websocket\r\n");
        stringBuilder.Append("Connection: Upgrade\r\n");
        if (subProtocol != null)
        {
            stringBuilder.AppendFormat("Sec-WebSocket-Protocol: {0}\r\n", subProtocol);
        }
        string acceptKey = CreateResponseKey(request["Sec-WebSocket-Key"]);
        stringBuilder.AppendFormat("Sec-WebSocket-Accept: {0}\r\n", acceptKey);
        stringBuilder.Append("\r\n");
        return Encoding.ASCII.GetBytes(stringBuilder.ToString());
    }

    /// <summary>Returns base64(SHA-1(requestKey + GUID)).</summary>
    public static string CreateResponseKey(string requestKey)
    {
        string combined = requestKey + WebSocketResponseGuid;
        using (SHA1 sha1 = SHA1.Create())
        {
            return Convert.ToBase64String(sha1.ComputeHash(Encoding.ASCII.GetBytes(combined)));
        }
    }

    // Strict UTF-8 decode; malformed input becomes InvalidFramePayloadData.
    private static string ReadUTF8PayloadData(byte[] bytes)
    {
        UTF8Encoding strictUtf8 = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);
        try
        {
            return strictUtf8.GetString(bytes);
        }
        catch (ArgumentException)
        {
            // DecoderFallbackException derives from ArgumentException.
            throw new WebSocketException(WebSocketStatusCodes.InvalidFramePayloadData);
        }
    }
}
}