diff -r 000000000000 -r 4f2f89ce4247 WebCore/websockets/WebSocketHandshake.cpp --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/WebCore/websockets/WebSocketHandshake.cpp Fri Sep 17 09:02:29 2010 +0300 @@ -0,0 +1,597 @@ +/* + * Copyright (C) 2009 Google Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include "config.h" + +#if ENABLE(WEB_SOCKETS) + +#include "WebSocketHandshake.h" + +#include "AtomicString.h" +#include "CharacterNames.h" +#include "Cookie.h" +#include "CookieJar.h" +#include "Document.h" +#include "HTTPHeaderMap.h" +#include "KURL.h" +#include "Logging.h" +#include "ScriptExecutionContext.h" +#include "SecurityOrigin.h" +#include "StringBuilder.h" + +#include +#include +#include +#include +#include +#include + +namespace WebCore { + +static const char randomCharacterInSecWebSocketKey[] = "!\"#$%&'()*+,-./:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~"; + +static String resourceName(const KURL& url) +{ + String name = url.path(); + if (name.isEmpty()) + name = "/"; + if (!url.query().isNull()) + name += "?" + url.query(); + ASSERT(!name.isEmpty()); + ASSERT(!name.contains(' ')); + return name; +} + +static String hostName(const KURL& url, bool secure) +{ + ASSERT(url.protocolIs("wss") == secure); + StringBuilder builder; + builder.append(url.host().lower()); + if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) { + builder.append(":"); + builder.append(String::number(url.port())); + } + return builder.toString(); +} + +static const size_t maxConsoleMessageSize = 128; +static String trimConsoleMessage(const char* p, size_t len) +{ + String s = String(p, std::min(len, maxConsoleMessageSize)); + if (len > maxConsoleMessageSize) + s.append(horizontalEllipsis); + return s; +} + +static void generateSecWebSocketKey(uint32_t& number, String& key) +{ + uint32_t space = static_cast(randomNumber() * 12) + 1; + uint32_t max = 4294967295U / space; + number = static_cast(randomNumber() * max); + uint32_t product = number * space; + + String s = String::number(product); + int n = static_cast(randomNumber() * 12) + 1; + DEFINE_STATIC_LOCAL(String, randomChars, (randomCharacterInSecWebSocketKey)); + for (int i = 0; i < n; i++) { + int pos = static_cast(randomNumber() * (s.length() + 1)); + int chpos = static_cast(randomNumber() * randomChars.length()); + s.insert(randomChars.substring(chpos, 1), pos); + } + DEFINE_STATIC_LOCAL(String, spaceChar, (" ")); + for (uint32_t i = 0; i < space; i++) { + int pos = static_cast(randomNumber() * (s.length() - 1)) + 1; + s.insert(spaceChar, pos); + } + ASSERT(s[0] != ' '); + ASSERT(s[s.length() - 1] != ' '); + key = s; +} + +static void generateKey3(unsigned char key3[8]) +{ + for (int i = 0; i < 8; i++) + key3[i] = randomNumber() * 256; +} + +static void setChallengeNumber(unsigned char* buf, uint32_t number) +{ + unsigned char* p = buf + 3; + for (int i = 0; i < 4; i++) { + *p = number & 0xFF; + --p; + number >>= 8; + } +} + +static void generateExpectedChallengeResponse(uint32_t number1, uint32_t number2, unsigned char key3[8], unsigned char expectedChallenge[16]) +{ + unsigned char challenge[16]; + setChallengeNumber(&challenge[0], number1); + setChallengeNumber(&challenge[4], number2); + memcpy(&challenge[8], key3, 8); + MD5 md5; + md5.addBytes(challenge, sizeof(challenge)); + Vector digest; + md5.checksum(digest); + memcpy(expectedChallenge, digest.data(), 16); +} + +WebSocketHandshake::WebSocketHandshake(const KURL& url, const String& protocol, ScriptExecutionContext* context) + : m_url(url) + , m_clientProtocol(protocol) + , m_secure(m_url.protocolIs("wss")) + , m_context(context) + , m_mode(Incomplete) +{ + uint32_t number1; + uint32_t number2; + generateSecWebSocketKey(number1, m_secWebSocketKey1); + generateSecWebSocketKey(number2, m_secWebSocketKey2); + generateKey3(m_key3); + generateExpectedChallengeResponse(number1, number2, m_key3, m_expectedChallengeResponse); +} + +WebSocketHandshake::~WebSocketHandshake() +{ +} + +const KURL& WebSocketHandshake::url() const +{ + return m_url; +} + +void WebSocketHandshake::setURL(const KURL& url) +{ + m_url = url.copy(); +} + +const String WebSocketHandshake::host() const +{ + return m_url.host().lower(); +} + +const String& WebSocketHandshake::clientProtocol() const +{ + return m_clientProtocol; +} + +void WebSocketHandshake::setClientProtocol(const String& protocol) +{ + m_clientProtocol = protocol; +} + +bool WebSocketHandshake::secure() const +{ + return m_secure; +} + +String WebSocketHandshake::clientOrigin() const +{ + return m_context->securityOrigin()->toString(); +} + +String WebSocketHandshake::clientLocation() const +{ + StringBuilder builder; + builder.append(m_secure ? "wss" : "ws"); + builder.append("://"); + builder.append(hostName(m_url, m_secure)); + builder.append(resourceName(m_url)); + return builder.toString(); +} + +CString WebSocketHandshake::clientHandshakeMessage() const +{ + // Keep the following consistent with clientHandshakeRequest(). + StringBuilder builder; + + builder.append("GET "); + builder.append(resourceName(m_url)); + builder.append(" HTTP/1.1\r\n"); + + Vector fields; + fields.append("Upgrade: WebSocket"); + fields.append("Connection: Upgrade"); + fields.append("Host: " + hostName(m_url, m_secure)); + fields.append("Origin: " + clientOrigin()); + if (!m_clientProtocol.isEmpty()) + fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol); + + KURL url = httpURLForAuthenticationAndCookies(); + if (m_context->isDocument()) { + Document* document = static_cast(m_context); + String cookie = cookieRequestHeaderFieldValue(document, url); + if (!cookie.isEmpty()) + fields.append("Cookie: " + cookie); + // Set "Cookie2: " if cookies 2 exists for url? + } + + fields.append("Sec-WebSocket-Key1: " + m_secWebSocketKey1); + fields.append("Sec-WebSocket-Key2: " + m_secWebSocketKey2); + + // Fields in the handshake are sent by the client in a random order; the + // order is not meaningful. Thus, it's ok to send the order we constructed + // the fields. + + for (size_t i = 0; i < fields.size(); i++) { + builder.append(fields[i]); + builder.append("\r\n"); + } + + builder.append("\r\n"); + + CString handshakeHeader = builder.toString().utf8(); + char* characterBuffer = 0; + CString msg = CString::newUninitialized(handshakeHeader.length() + sizeof(m_key3), characterBuffer); + memcpy(characterBuffer, handshakeHeader.data(), handshakeHeader.length()); + memcpy(characterBuffer + handshakeHeader.length(), m_key3, sizeof(m_key3)); + return msg; +} + +WebSocketHandshakeRequest WebSocketHandshake::clientHandshakeRequest() const +{ + // Keep the following consistent with clientHandshakeMessage(). + // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and + // m_key3 in WebSocketHandshakeRequest? + WebSocketHandshakeRequest request("GET", m_url); + request.addHeaderField("Upgrade", "WebSocket"); + request.addHeaderField("Connection", "Upgrade"); + request.addHeaderField("Host", hostName(m_url, m_secure)); + request.addHeaderField("Origin", clientOrigin()); + if (!m_clientProtocol.isEmpty()) + request.addHeaderField("Sec-WebSocket-Protocol:", m_clientProtocol); + + KURL url = httpURLForAuthenticationAndCookies(); + if (m_context->isDocument()) { + Document* document = static_cast(m_context); + String cookie = cookieRequestHeaderFieldValue(document, url); + if (!cookie.isEmpty()) + request.addHeaderField("Cookie", cookie); + // Set "Cookie2: " if cookies 2 exists for url? + } + + request.addHeaderField("Sec-WebSocket-Key1", m_secWebSocketKey1); + request.addHeaderField("Sec-WebSocket-Key2", m_secWebSocketKey2); + request.setKey3(m_key3); + + return request; +} + +void WebSocketHandshake::reset() +{ + m_mode = Incomplete; + + m_wsOrigin = String(); + m_wsLocation = String(); + m_wsProtocol = String(); + m_setCookie = String(); + m_setCookie2 = String(); +} + +void WebSocketHandshake::clearScriptExecutionContext() +{ + m_context = 0; +} + +int WebSocketHandshake::readServerHandshake(const char* header, size_t len) +{ + m_mode = Incomplete; + int statusCode; + String statusText; + int lineLength = readStatusLine(header, len, statusCode, statusText); + if (lineLength == -1) + return -1; + if (statusCode == -1) { + m_mode = Failed; + return len; + } + LOG(Network, "response code: %d", statusCode); + m_response.setStatusCode(statusCode); + m_response.setStatusText(statusText); + if (statusCode != 101) { + m_mode = Failed; + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, String::format("Unexpected response code: %d", statusCode), 0, clientOrigin()); + return len; + } + m_mode = Normal; + if (!strnstr(header, "\r\n\r\n", len)) { + // Just hasn't been received fully yet. + m_mode = Incomplete; + return -1; + } + const char* p = readHTTPHeaders(header + lineLength, header + len); + if (!p) { + LOG(Network, "readHTTPHeaders failed"); + m_mode = Failed; + return len; + } + processHeaders(); + if (!checkResponseHeaders()) { + LOG(Network, "header process failed"); + m_mode = Failed; + return p - header; + } + if (len < static_cast(p - header + sizeof(m_expectedChallengeResponse))) { + // Just hasn't been received /expected/ yet. + m_mode = Incomplete; + return -1; + } + m_response.setChallengeResponse(static_cast(static_cast(p))); + if (memcmp(p, m_expectedChallengeResponse, sizeof(m_expectedChallengeResponse))) { + m_mode = Failed; + return (p - header) + sizeof(m_expectedChallengeResponse); + } + m_mode = Connected; + return (p - header) + sizeof(m_expectedChallengeResponse); +} + +WebSocketHandshake::Mode WebSocketHandshake::mode() const +{ + return m_mode; +} + +const String& WebSocketHandshake::serverWebSocketOrigin() const +{ + return m_wsOrigin; +} + +void WebSocketHandshake::setServerWebSocketOrigin(const String& webSocketOrigin) +{ + m_wsOrigin = webSocketOrigin; +} + +const String& WebSocketHandshake::serverWebSocketLocation() const +{ + return m_wsLocation; +} + +void WebSocketHandshake::setServerWebSocketLocation(const String& webSocketLocation) +{ + m_wsLocation = webSocketLocation; +} + +const String& WebSocketHandshake::serverWebSocketProtocol() const +{ + return m_wsProtocol; +} + +void WebSocketHandshake::setServerWebSocketProtocol(const String& webSocketProtocol) +{ + m_wsProtocol = webSocketProtocol; +} + +const String& WebSocketHandshake::serverSetCookie() const +{ + return m_setCookie; +} + +void WebSocketHandshake::setServerSetCookie(const String& setCookie) +{ + m_setCookie = setCookie; +} + +const String& WebSocketHandshake::serverSetCookie2() const +{ + return m_setCookie2; +} + +void WebSocketHandshake::setServerSetCookie2(const String& setCookie2) +{ + m_setCookie2 = setCookie2; +} + +const WebSocketHandshakeResponse& WebSocketHandshake::serverHandshakeResponse() const +{ + return m_response; +} + +KURL WebSocketHandshake::httpURLForAuthenticationAndCookies() const +{ + KURL url = m_url.copy(); + bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http"); + ASSERT_UNUSED(couldSetProtocol, couldSetProtocol); + return url; +} + +// Returns the header length (including "\r\n"), or -1 if we have not received enough data yet. +// If the line is malformed or the status code is not a 3-digit number, +// statusCode and statusText will be set to -1 and a null string, respectively. +int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, int& statusCode, String& statusText) +{ + statusCode = -1; + statusText = String(); + + const char* space1 = 0; + const char* space2 = 0; + const char* p; + size_t consumedLength; + + for (p = header, consumedLength = 0; consumedLength < headerLength; p++, consumedLength++) { + if (*p == ' ') { + if (!space1) + space1 = p; + else if (!space2) + space2 = p; + } else if (*p == '\n') + break; + } + if (consumedLength == headerLength) + return -1; // We have not received '\n' yet. + + const char* end = p + 1; + if (end - header > INT_MAX) { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Status line is too long: " + trimConsoleMessage(header, maxConsoleMessageSize + 1), 0, clientOrigin()); + return INT_MAX; + } + int lineLength = end - header; + + if (!space1 || !space2) { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "No response code found: " + trimConsoleMessage(header, lineLength - 1), 0, clientOrigin()); + return lineLength; + } + + // The line must end with "\r\n". + if (*(end - 2) != '\r') { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Status line does not end with CRLF", 0, clientOrigin()); + return lineLength; + } + + String statusCodeString(space1 + 1, space2 - space1 - 1); + if (statusCodeString.length() != 3) // Status code must consist of three digits. + return lineLength; + for (int i = 0; i < 3; ++i) + if (statusCodeString[i] < '0' || statusCodeString[i] > '9') { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Invalid status code: " + statusCodeString, 0, clientOrigin()); + return lineLength; + } + + bool ok = false; + statusCode = statusCodeString.toInt(&ok); + ASSERT(ok); + + statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n". + return lineLength; +} + +const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end) +{ + m_response.clearHeaderFields(); + + Vector name; + Vector value; + for (const char* p = start; p < end; p++) { + name.clear(); + value.clear(); + + for (; p < end; p++) { + switch (*p) { + case '\r': + if (name.isEmpty()) { + if (p + 1 < end && *(p + 1) == '\n') + return p + 2; + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF at " + trimConsoleMessage(p, end - p), 0, clientOrigin()); + return 0; + } + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected CR in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin()); + return 0; + case '\n': + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in name at " + trimConsoleMessage(name.data(), name.size()), 0, clientOrigin()); + return 0; + case ':': + break; + default: + name.append(*p); + continue; + } + if (*p == ':') { + ++p; + break; + } + } + + for (; p < end && *p == 0x20; p++) { } + + for (; p < end; p++) { + switch (*p) { + case '\r': + break; + case '\n': + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Unexpected LF in value at " + trimConsoleMessage(value.data(), value.size()), 0, clientOrigin()); + return 0; + default: + value.append(*p); + } + if (*p == '\r') { + ++p; + break; + } + } + if (p >= end || *p != '\n') { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "CR doesn't follow LF after value at " + trimConsoleMessage(p, end - p), 0, clientOrigin()); + return 0; + } + AtomicString nameStr(String::fromUTF8(name.data(), name.size())); + String valueStr = String::fromUTF8(value.data(), value.size()); + if (nameStr.isNull()) { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "invalid UTF-8 sequence in header name", 0, clientOrigin()); + return 0; + } + if (valueStr.isNull()) { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "invalid UTF-8 sequence in header value", 0, clientOrigin()); + return 0; + } + LOG(Network, "name=%s value=%s", nameStr.string().utf8().data(), valueStr.utf8().data()); + m_response.addHeaderField(nameStr, valueStr); + } + ASSERT_NOT_REACHED(); + return 0; +} + +void WebSocketHandshake::processHeaders() +{ + ASSERT(m_mode == Normal); + const HTTPHeaderMap& headers = m_response.headerFields(); + m_wsOrigin = headers.get("sec-websocket-origin"); + m_wsLocation = headers.get("sec-websocket-location"); + m_wsProtocol = headers.get("sec-websocket-protocol"); + m_setCookie = headers.get("set-cookie"); + m_setCookie2 = headers.get("set-cookie2"); +} + +bool WebSocketHandshake::checkResponseHeaders() +{ + if (m_wsOrigin.isNull()) { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'sec-websocket-origin' header is missing", 0, clientOrigin()); + return false; + } + if (m_wsLocation.isNull()) { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: 'sec-websocket-location' header is missing", 0, clientOrigin()); + return false; + } + + if (clientOrigin() != m_wsOrigin) { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: origin mismatch: " + clientOrigin() + " != " + m_wsOrigin, 0, clientOrigin()); + return false; + } + if (clientLocation() != m_wsLocation) { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: location mismatch: " + clientLocation() + " != " + m_wsLocation, 0, clientOrigin()); + return false; + } + if (!m_clientProtocol.isEmpty() && m_clientProtocol != m_wsProtocol) { + m_context->addMessage(JSMessageSource, LogMessageType, ErrorMessageLevel, "Error during WebSocket handshake: protocol mismatch: " + m_clientProtocol + " != " + m_wsProtocol, 0, clientOrigin()); + return false; + } + return true; +} + +} // namespace WebCore + +#endif // ENABLE(WEB_SOCKETS)