|
@@ -31,63 +31,74 @@ func TestSecWebSocketAccept(t *testing.T) {
|
|
|
}
|
|
|
|
|
|
func TestHybiClientHandshake(t *testing.T) {
|
|
|
- b := bytes.NewBuffer([]byte{})
|
|
|
- bw := bufio.NewWriter(b)
|
|
|
- br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
|
|
|
+ type test struct {
|
|
|
+ url, host string
|
|
|
+ }
|
|
|
+ tests := []test{
|
|
|
+ {"ws://server.example.com/chat", "server.example.com"},
|
|
|
+ {"ws://127.0.0.1/chat", "127.0.0.1"},
|
|
|
+ }
|
|
|
+ if _, err := url.ParseRequestURI("http://[fe80::1%25lo0]"); err == nil {
|
|
|
+ tests = append(tests, test{"ws://[fe80::1%25lo0]/chat", "[fe80::1]"})
|
|
|
+ }
|
|
|
+
|
|
|
+ for _, tt := range tests {
|
|
|
+ var b bytes.Buffer
|
|
|
+ bw := bufio.NewWriter(&b)
|
|
|
+ br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
|
|
|
Upgrade: websocket
|
|
|
Connection: Upgrade
|
|
|
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
|
|
|
Sec-WebSocket-Protocol: chat
|
|
|
|
|
|
`))
|
|
|
- var err error
|
|
|
- config := new(Config)
|
|
|
- config.Location, err = url.ParseRequestURI("ws://server.example.com/chat")
|
|
|
- if err != nil {
|
|
|
- t.Fatal("location url", err)
|
|
|
- }
|
|
|
- config.Origin, err = url.ParseRequestURI("http://example.com")
|
|
|
- if err != nil {
|
|
|
- t.Fatal("origin url", err)
|
|
|
- }
|
|
|
- config.Protocol = append(config.Protocol, "chat")
|
|
|
- config.Protocol = append(config.Protocol, "superchat")
|
|
|
- config.Version = ProtocolVersionHybi13
|
|
|
-
|
|
|
- config.handshakeData = map[string]string{
|
|
|
- "key": "dGhlIHNhbXBsZSBub25jZQ==",
|
|
|
- }
|
|
|
- err = hybiClientHandshake(config, br, bw)
|
|
|
- if err != nil {
|
|
|
- t.Errorf("handshake failed: %v", err)
|
|
|
- }
|
|
|
- req, err := http.ReadRequest(bufio.NewReader(b))
|
|
|
- if err != nil {
|
|
|
- t.Fatalf("read request: %v", err)
|
|
|
- }
|
|
|
- if req.Method != "GET" {
|
|
|
- t.Errorf("request method expected GET, but got %q", req.Method)
|
|
|
- }
|
|
|
- if req.URL.Path != "/chat" {
|
|
|
- t.Errorf("request path expected /chat, but got %q", req.URL.Path)
|
|
|
- }
|
|
|
- if req.Proto != "HTTP/1.1" {
|
|
|
- t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto)
|
|
|
- }
|
|
|
- if req.Host != "server.example.com" {
|
|
|
- t.Errorf("request Host expected server.example.com, but got %v", req.Host)
|
|
|
- }
|
|
|
- var expectedHeader = map[string]string{
|
|
|
- "Connection": "Upgrade",
|
|
|
- "Upgrade": "websocket",
|
|
|
- "Sec-Websocket-Key": config.handshakeData["key"],
|
|
|
- "Origin": config.Origin.String(),
|
|
|
- "Sec-Websocket-Protocol": "chat, superchat",
|
|
|
- "Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13),
|
|
|
- }
|
|
|
- for k, v := range expectedHeader {
|
|
|
- if req.Header.Get(k) != v {
|
|
|
- t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
|
|
|
+ var err error
|
|
|
+ var config Config
|
|
|
+ config.Location, err = url.ParseRequestURI(tt.url)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal("location url", err)
|
|
|
+ }
|
|
|
+ config.Origin, err = url.ParseRequestURI("http://example.com")
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal("origin url", err)
|
|
|
+ }
|
|
|
+ config.Protocol = append(config.Protocol, "chat")
|
|
|
+ config.Protocol = append(config.Protocol, "superchat")
|
|
|
+ config.Version = ProtocolVersionHybi13
|
|
|
+ config.handshakeData = map[string]string{
|
|
|
+ "key": "dGhlIHNhbXBsZSBub25jZQ==",
|
|
|
+ }
|
|
|
+ if err := hybiClientHandshake(&config, br, bw); err != nil {
|
|
|
+ t.Fatal("handshake", err)
|
|
|
+ }
|
|
|
+ req, err := http.ReadRequest(bufio.NewReader(&b))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal("read request", err)
|
|
|
+ }
|
|
|
+ if req.Method != "GET" {
|
|
|
+ t.Errorf("request method expected GET, but got %s", req.Method)
|
|
|
+ }
|
|
|
+ if req.URL.Path != "/chat" {
|
|
|
+ t.Errorf("request path expected /chat, but got %s", req.URL.Path)
|
|
|
+ }
|
|
|
+ if req.Proto != "HTTP/1.1" {
|
|
|
+ t.Errorf("request proto expected HTTP/1.1, but got %s", req.Proto)
|
|
|
+ }
|
|
|
+ if req.Host != tt.host {
|
|
|
+ t.Errorf("request host expected %s, but got %s", tt.host, req.Host)
|
|
|
+ }
|
|
|
+ var expectedHeader = map[string]string{
|
|
|
+ "Connection": "Upgrade",
|
|
|
+ "Upgrade": "websocket",
|
|
|
+ "Sec-Websocket-Key": config.handshakeData["key"],
|
|
|
+ "Origin": config.Origin.String(),
|
|
|
+ "Sec-Websocket-Protocol": "chat, superchat",
|
|
|
+ "Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13),
|
|
|
+ }
|
|
|
+ for k, v := range expectedHeader {
|
|
|
+ if req.Header.Get(k) != v {
|
|
|
+ t.Errorf("%s expected %s, but got %v", k, v, req.Header.Get(k))
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|