| package common |
|
|
import (
	"strings"
	"testing"

	"github.com/QuantumNous/new-api/constant"
)
|
|
| func TestValidateRedirectURL(t *testing.T) { |
| |
| originalDomains := constant.TrustedRedirectDomains |
| defer func() { |
| constant.TrustedRedirectDomains = originalDomains |
| }() |
|
|
| tests := []struct { |
| name string |
| url string |
| trustedDomains []string |
| wantErr bool |
| errContains string |
| }{ |
| |
| { |
| name: "exact domain match with https", |
| url: "https://example.com/success", |
| trustedDomains: []string{"example.com"}, |
| wantErr: false, |
| }, |
| { |
| name: "exact domain match with http", |
| url: "http://example.com/callback", |
| trustedDomains: []string{"example.com"}, |
| wantErr: false, |
| }, |
| { |
| name: "subdomain match", |
| url: "https://sub.example.com/success", |
| trustedDomains: []string{"example.com"}, |
| wantErr: false, |
| }, |
| { |
| name: "case insensitive domain", |
| url: "https://EXAMPLE.COM/success", |
| trustedDomains: []string{"example.com"}, |
| wantErr: false, |
| }, |
|
|
| |
| { |
| name: "untrusted domain", |
| url: "https://evil.com/phishing", |
| trustedDomains: []string{"example.com"}, |
| wantErr: true, |
| errContains: "not in the trusted domains list", |
| }, |
| { |
| name: "suffix attack - fakeexample.com", |
| url: "https://fakeexample.com/success", |
| trustedDomains: []string{"example.com"}, |
| wantErr: true, |
| errContains: "not in the trusted domains list", |
| }, |
| { |
| name: "empty trusted domains list", |
| url: "https://example.com/success", |
| trustedDomains: []string{}, |
| wantErr: true, |
| errContains: "not in the trusted domains list", |
| }, |
|
|
| |
| { |
| name: "javascript scheme", |
| url: "javascript:alert('xss')", |
| trustedDomains: []string{"example.com"}, |
| wantErr: true, |
| errContains: "invalid URL scheme", |
| }, |
| { |
| name: "data scheme", |
| url: "data:text/html,<script>alert('xss')</script>", |
| trustedDomains: []string{"example.com"}, |
| wantErr: true, |
| errContains: "invalid URL scheme", |
| }, |
|
|
| |
| { |
| name: "empty URL", |
| url: "", |
| trustedDomains: []string{"example.com"}, |
| wantErr: true, |
| errContains: "invalid URL scheme", |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| |
| constant.TrustedRedirectDomains = tt.trustedDomains |
|
|
| err := ValidateRedirectURL(tt.url) |
|
|
| if tt.wantErr { |
| if err == nil { |
| t.Errorf("ValidateRedirectURL(%q) expected error containing %q, got nil", tt.url, tt.errContains) |
| return |
| } |
| if tt.errContains != "" && !contains(err.Error(), tt.errContains) { |
| t.Errorf("ValidateRedirectURL(%q) error = %q, want error containing %q", tt.url, err.Error(), tt.errContains) |
| } |
| } else { |
| if err != nil { |
| t.Errorf("ValidateRedirectURL(%q) unexpected error: %v", tt.url, err) |
| } |
| } |
| }) |
| } |
| } |
|
|
// contains reports whether substr occurs anywhere within s.
//
// It is a thin wrapper over strings.Contains: an empty substr is contained in
// every string (including the empty string), and a substr longer than s is
// never contained.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
|
// findSubstring reports whether substr occurs anywhere within s.
//
// It delegates to strings.Contains rather than scanning byte windows by hand.
// An empty substr always matches, and a substr longer than s never does.
func findSubstring(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
|