Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 16 additions & 25 deletions lib/onelogin/ruby-saml/saml_message.rb
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
require 'cgi'
require 'zlib'
require 'base64'
require "nokogiri"
require "rexml/document"
require "rexml/xpath"
require "thread"
require 'nokogiri'
require 'rexml/document'
require 'rexml/xpath'
require 'thread'

module OneLogin
module RubySaml
Expand All @@ -14,7 +14,7 @@ class SamlMessage
ASSERTION = "urn:oasis:names:tc:SAML:2.0:assertion"
PROTOCOL = "urn:oasis:names:tc:SAML:2.0:protocol"

BASE64_FORMAT_REGEXP = %r{\A(([A-Za-z0-9+/]{4}))*([A-Za-z0-9+/]{4}|[A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)\Z}
BASE64_FORMAT = %r(\A[A-Za-z0-9+/]{4}*[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=?\Z)

def self.schema
@schema ||= Mutex.new.synchronize do
Expand Down Expand Up @@ -48,7 +48,7 @@ def validation_error(message)
# is to try and inflate it and fall back to the base64 decoded string if
# the stream contains errors.
def decode_raw_saml(saml)
return saml unless base64_formatted?(saml)
return saml unless base64_encoded?(saml)

decoded = decode(saml)
begin
Expand All @@ -59,9 +59,9 @@ def decode_raw_saml(saml)
end

def encode_raw_saml(saml, settings)
saml = Zlib::Deflate.deflate(saml, 9)[2..-5] if settings.compress_request
base64_saml = Base64.encode64(saml)
return CGI.escape(base64_saml)
saml = deflate(saml) if settings.compress_request

CGI.escape(Base64.encode64(saml))
end

def decode(encoded)
Expand All @@ -72,30 +72,21 @@ def encode(encoded)
Base64.encode64(encoded).gsub(/\n/, "")
end

# Check if the provided string is base64 encoded.
# @param message [String] The value to be checked.
# @return [Boolean] True if the value is a base64 encoded string.
def base64_formatted?(string)
string.gsub(/[\r\n]|\\r|\\n/, "").match(BASE64_FORMAT_REGEXP)
end

def escape(unescaped)
CGI.escape(unescaped)
end

def unescape(escaped)
CGI.unescape(escaped)
# Check if a string is base64 encoded
#
# @param string [String] string to check the encoding of
# @return [true, false] whether or not the string is base64 encoded
def base64_encoded?(string)
!!string.gsub(/[\r\n]|\\r|\\n/, "").match(BASE64_FORMAT)
end

def inflate(deflated)
zlib = Zlib::Inflate.new(-Zlib::MAX_WBITS)
zlib.inflate(deflated)
Zlib::Inflate.new(-Zlib::MAX_WBITS).inflate(deflated)
end

def deflate(inflated)
Zlib::Deflate.deflate(inflated, 9)[2..-5]
end

end
end
end