#
# $Id: httpclient.icn,v 1.13 2004/10/15 22:17:14 rparlett Exp $
#

package http

import util
import lang
import net
import mail

$include "posix.icn"

global token_char

#
# Little helper class to store cookie details
#
class Cookie:Object(name, value, expires, domain, path, secure)
   # name, value  - taken from the first "name=value" pair of a Set-Cookie header.
   # expires      - a Time object; only set when the "expires" attribute could be parsed.
   # domain, path - copied from the "domain" and "path" attributes when present.
   # secure       - meant to be non-null when the cookie carries the "secure" attribute.
end

#
# An http client
#
class HttpClient : NetClient(request,           # clone of the HttpRequest currently being retrieved
                             result,            # HttpResponse built by the latest attempt
                             open_url,          # clone of the URL the connection was opened for
                             data_out,          # file opened on request.data_file, when one is set
                             retries,           # number of retries after a failed attempt
                             length,            # content-length of the current download, or &null
                             string_buff,       # StringBuff collecting the body when no data_file is set
                             read,              # number of body bytes read so far
                             http_version,      # version string written on the request line
                             http_error,        # message of the last HTTP-level error, or &null
                             keep_alive_flag,   # non-null => send "Connection: Keep-Alive"
                             user_agent,        # value of the user-agent header
                             redir_set,         # set of URLs already visited while following redirects
                             auth_req_count,    # count of 401 challenges answered for the current URL
                             auth_scheme,       # "Basic" or "Digest" once a challenge has been seen
                             basic_auth_header, # precomputed value of the Basic authorization header
                             nonce,             # digest auth: nonce from the challenge
                             nonce_count,       # digest auth: number of times the nonce has been used
                             cnonce,            # digest auth: client nonce
                             opaque,            # digest auth: opaque value from the challenge (may be &null)
                             ha1,               # digest auth: hash over username, realm and password
                             qop,               # digest auth: selected quality of protection, or &null
                             realm,             # digest auth: realm from the challenge
                             cookies            # table: domain -> table: path -> ClTable of Cookie by name
                             )

   #
   # Set the http version to use; by default "1.1"
   #
   method set_http_version(s)
      # The string is appended to "HTTP/" when the request line is written.
      return self.http_version := s
   end

   #
   # Set the user-agent identification
   #
   method set_user_agent(s)
      # The value is sent as the "user-agent" header of every request.
      return user_agent := s
   end

   #
   # Configure the client to use the keep-alive feature (the
   # default).
   #
   method set_keep_alive()
      # Any non-null value makes write_request() add "Connection: Keep-Alive".
      keep_alive_flag := 1
   end

   #
   # Configure the client to NOT use the keep-alive feature.
   #
   method clear_keep_alive()
      # A null flag means no "Connection: Keep-Alive" header is added.
      keep_alive_flag := &null
   end

   #
   # @p
   method write_request_headers()
      local entry, name, value
      #
      # Write every request header as "Name: value", one line per value,
      # in sorted order of header name.  Fails as soon as a line cannot
      # be written.
      #
      every entry := !request.headers.sort() do {
         name := entry[1]
         # Capitalise the first character of the header name only.
         name[1] := map(name[1], &lcase, &ucase)
         every value := !entry[2] do
            write_line(name || ": " || value) | fail
      }
      return
   end

   #
   # Set the number of retries to use.  The default is two.
   #
   method set_retries(retries)
      # The parameter shadows the field of the same name, hence the explicit self.
      return self.retries := retries
   end

   #
   # Retrieve the given {HttpRequest} request, or fail if that
   # is not possible.
   #
   # @param request - an {HttpRequest} instance
   # @return an {HttpResponse} object.
   method retrieve(request)
      local status
      #
      # Work on a private copy of the request, since redirects and
      # authentication alter its url, method and headers.
      #
      self.request := request.clone()
      auth_req_count := 0
      redir_set := set()

      #
      # Keep fetching until a final answer is reached: success, an
      # unrecoverable HTTP status, or a failure in one of the handlers.
      #
      repeat {
         http_error := &null

         retrieve_page() | fail
         status := result.get_status()

         if find("200" | "206", status) then
            return result

         # Anything else: discard whatever was written to the data file.
         remove(\request.data_file)

         if find("401", status) then
            handle_authentication() | fail
         else if find("301" | "302" | "303" | "307", status) then
            handle_redirect() | fail
         else
            return on_http_error(status)
      }
   end

   #
   # @p
   method handle_redirect()
      local location, target
      #
      # Point the request at the URL named by the Location header of a
      # redirect response.  Fails (via on_http_error) when the header is
      # missing or invalid, or when the target was already visited.
      #
      location := result.get_first_header("Location") | return on_http_error("No Location in a redirect response")

      if match("http://", location) then {
         request.url := URL(location) | return on_http_error("Invalid Location in a redirect response:" || location)
      } else {
         # Strictly it should be absolute, but relative URLs are commonplace.
         request.url.set_relative(location)
      }

      # Guard against redirect loops.
      target := request.url.to_string()
      if member(redir_set, target) then
         return on_http_error("Circular redirection detected:" || target)
      insert(redir_set, target)

      # On a redirect, a POST becomes a GET and loses its body headers.
      if request.the_method == "POST" then {
         request.the_method := "GET"
         request.post_data := &null
         request.unset_header("content-length")
         request.unset_header("content-type")
      }

      # The new location gets a fresh authentication attempt.
      auth_req_count := 0
      return
   end

   #
   # @p
   method handle_authentication()
      local challenge, params
      #
      # React to a 401 response: parse the WWW-Authenticate challenge and
      # prepare the credentials that write_request() will send on the
      # next attempt.  Fails (via on_http_error) when no credentials are
      # set, when credentials were already tried for this URL, or when
      # the challenge is missing or uses an unknown scheme.
      #
      if /(request.username | request.password) then
         return on_http_error("Authentication requested - please set username, password")

      # Only one attempt per URL; a second 401 means the credentials are wrong.
      if auth_req_count > 0 then
         return on_http_error("Failed to authenticate - correct username, password")
      auth_req_count +:= 1

      challenge := result.get_first_header("www-authenticate") |
         return on_http_error("No WWW-Authenticate in a 401 response")

      #
      # The parsed challenge must be handed on to the setup methods; the
      # digest setup reads realm, nonce, opaque, algorithm and qop from it.
      #
      params := parse_generic_header(challenge)
      if member(params, "Basic") then {
         auth_scheme := "Basic"
         return setup_basic_authentication(params)
      }
      if member(params, "Digest") then {
         auth_scheme := "Digest"
         return setup_digest_authentication(params)
      }

      return on_http_error("WWW-Authenticate header contained unknown authentication method.")
   end

   #
   # @p
   method setup_basic_authentication(t)
      local encoder, encoded
      #
      # Precompute the Basic authorization header value from the
      # request's username and password.  The parameter t is not used.
      #
      encoder := Base64Handler()
      encoded := ""
      encoder.encode_data(&null, request.username || ":" || request.password) ? {
         # Copy the encoded text, dropping every "\r\n" it contains.
         while encoded ||:= tab(find("\r\n")) do
            move(2)
         encoded ||:= tab(0)
      }
      basic_auth_header := "Basic " || encoded
      return
   end

   #
   # @p
   method setup_digest_authentication(t)
      local algorithm, md5, qop_options
      #
      # Initialise the digest authentication state (realm, nonce, opaque,
      # HA1 and qop) from a parsed WWW-Authenticate challenge.
      #
      # @param t  the table produced by parse_generic_header(); when it is
      #           omitted the challenge of the current result is parsed.
      # Succeeds once the state is set up; fails via on_http_error() when
      # the challenge lacks a realm or a nonce.
      #
      /t := parse_generic_header(result.get_first_header("www-authenticate")) |
         return on_http_error("No WWW-Authenticate in a 401 response")

      self.realm := \t["realm"] | 
         return on_http_error("WWW-Authenticate digest header didn't contain a realm")
      self.nonce := \t["nonce"] | 
         return on_http_error("WWW-Authenticate digest header didn't contain a nonce")
      self.opaque := t["opaque"] 

      algorithm := \t["algorithm"] | "MD5"
      self.nonce_count := 0
      self.cnonce := "0a4f113b"
      
      #
      # Calculate HA1 = H(username:realm:password)
      #
      md5 := MD5()
      md5.update(unq(request.username) || ":" || unq(realm) || ":" || request.password)
      self.ha1 := md5.final_str()
      if map(unq(algorithm)) == "md5-sess" then {
         #
         # For MD5-sess, RFC 2617 defines HA1 as
         # H( H(username:realm:password) : nonce : cnonce ), i.e. the
         # first hash is itself hashed together with the nonces.
         #
         md5 := MD5()
         md5.update(ha1 || ":" || unq(nonce) || ":" || unq(cnonce))
         self.ha1 := md5.final_str()
      }

      #
      # Get the available qop values and select a qop; "auth-int" is
      # preferred over "auth", and qop is &null if neither is offered.
      #
      qop_options := set()
      \t["qop"] ? {
         while(tab(upto(token_char))) do 
            insert(qop_options, map(tab(many(token_char))))
      }
      self.qop := member(qop_options, "auth-int" | "auth") | &null

      # Without an explicit return the method would fail and abort retrieve().
      return
   end

   #
   # @p
   method create_digest_authorization_header()
      local md5, nc, h, body_md5, hentity, ha2, uri
      #
      # Build the value of the Authorization header for digest
      # authentication, using the state prepared by
      # setup_digest_authentication().  Each call increments the nonce
      # count.
      #
      # @return the header value, starting with "Digest ".
      #
      self.nonce_count +:= 1
      nc := map(format_int_to_string(self.nonce_count, 16, 8))
      uri := request.url.get_file()

      #
      # Calculate HA2 = H(method:uri), with ":H(entity)" appended for auth-int.
      #
      md5 := MD5()
      md5.update(request.the_method || ":" || uri)
      if \qop == "auth-int" then {
         # Calculate H(entity); with no post data this is the hash of nothing.
         body_md5 := MD5()
         body_md5.update(\request.post_data)
         hentity := body_md5.final_str()
         md5.update(":" || hentity)
      }
      ha2 := md5.final_str()

      #
      # Calculate the response = H(HA1:nonce:[nc:cnonce:qop:]HA2).  A fresh
      # digest object is used so that the result does not depend on
      # final_str() having reset the previous one.
      #
      md5 := MD5()
      md5.update(ha1 || ":" || unq(nonce) || ":")
      if \self.qop then
         md5.update(nc || ":" || unq(cnonce) || ":" || unq(qop) || ":")
      md5.update(ha2)

      #
      # realm, nonce and opaque are echoed exactly as they appeared in the
      # challenge.  cnonce is a quoted-string in the header grammar, so it
      # is quoted explicitly here.
      #
      h := "Digest username=\"" || request.username || "\",\r\n\t" ||
         "realm=" || realm || ",\r\n\t" ||
         "nonce=" || nonce || ",\r\n\t" ||
         "uri=\"" || uri || "\",\r\n\t"
      if \qop then {
         h ||:= "qop=" || qop || ",\r\n\t" ||
            "nc=" || nc || ",\r\n\t" ||
            "cnonce=\"" || unq(cnonce) || "\",\r\n\t"
      }
      if \opaque then
         h ||:= "opaque=" || opaque || ",\r\n\t"
      h ||:= "response=\"" || md5.final_str() || "\""

      return h
   end

   #
   # @p
   method unq(s)
      #
      # Return s without a leading and/or a trailing double quote;
      # a string without surrounding quotes is returned unchanged.
      #
      if match("\"", s) then
         s := s[2:0]
      if s[-1] == "\"" then
         s := s[1:-1]
      return s
   end

   #
   # @p
   method parse_generic_header(s)
      local params, name, val
      #
      # Split a header value into a table of token -> value.  A token
      # followed by "=" gets the quoted string or token after it (or "");
      # a token on its own, such as the scheme name, maps to &null.
      # Quoted values keep their surrounding quotes.
      #
      params := table()
      s ? {
         # Skip separators up to the next token; stop when none is left.
         while tab(upto(token_char)) do {
            name := tab(many(token_char))
            if ="=" then {
               if any('\"') then
                  val := parse_quoted_string()
               else
                  val := tab(many(token_char)) | ""
            } else
               val := &null
            insert(params, name, val)
         }
      }
      return params
   end

   #
   # @p
   method parse_quoted_string()
      local res
      #
      # Scan a quoted string in the current scanning environment and
      # return it including its quotes and any backslash escapes.  The
      # scan stops after the closing quote, or at the end of the subject
      # when the string is unterminated.
      #
      res := ""
      res ||:= ="\""
      repeat {
         res ||:= tab(upto('\\\"') | 0)
         if pos(0) then
            break
         if res ||:= ="\"" then
            break
         if any('\\') then {
            #
            # Copy the backslash together with the escaped character.  A
            # lone backslash at the very end has nothing after it; take
            # the rest so that the scan always advances and terminates.
            #
            res ||:= move(2) | tab(0)
         }
      }
      return res
   end

   #
   # @p
   method create_authorization_header()
      #
      # Produce the Authorization header value for the scheme selected by
      # handle_authentication(); fails when no scheme has been selected.
      #
      if auth_scheme === "Basic" then
         return create_basic_authorization_header()
      if auth_scheme === "Digest" then
         return create_digest_authorization_header()
      fail
   end

   #
   # @p
   method create_basic_authorization_header()
      # The value was precomputed by setup_basic_authentication().
      return self.basic_auth_header
   end

   #
   # @p
   method parse_cookie_string(s)
      local c, k, v, t
      #
      # Parse the value of a Set-Cookie header into a Cookie.  The first
      # "name=value" pair gives the cookie's name and value; the pairs
      # that follow are treated as attributes (expires, path, domain,
      # secure), matched case-insensitively.  Unknown attributes are
      # ignored.
      #
      # @return the Cookie; fails if s contained no token at all.
      #
      s ? repeat {
         tab(upto(token_char)) | break
         k := tab(many(token_char))
         if ="=" then {
            if any('\"') then
               v := parse_quoted_string()
            else
               v := tab(upto(';') | 0)
         } else
            v := &null
         if /c then {
            c := Cookie()
            c.name := k
            c.value := v
         } else {
            case map(k) of {
               "expires" : {
                  # Only record the expiry time when it can be parsed.
                  t := Time()
                  if t.parse(v, "EEE, dd-MMM-yyyy hh:mm:ss zzz") then
                     c.expires := t
               }
               "path" : c.path := v
               "domain" : c.domain := v
               # The flag belongs to the cookie, not to a stray variable.
               "secure" : c.secure := 1
            }
         }
      }

      return \c
   end

   #
   # Store any cookie values in the current page's header.
   # @p
   method store_cookies()
      local s, c, domain, paths, path, vals
      #
      # Record every cookie found in the "set-cookie" headers of the
      # current result.  Cookies are filed under their domain (lower
      # case; the request's address when the cookie names none), then
      # their path (the request's file when the cookie names none), then
      # their name, so a later cookie replaces an earlier one of the same
      # name.
      #
      every s := !result.get_headers("set-cookie") do {
         if c := parse_cookie_string(s) then {
            domain := map(\c.domain | request.url.get_address())
            if member(cookies, domain) then 
               paths := cookies[domain]
            else {
               paths := table()
               insert(cookies, domain, paths)
            }

            path := \c.path | request.url.get_file()

            if member(paths, path) then
               vals := paths[path]
            else {
               vals := ClTable()
               insert(paths, path, vals)
            }
            vals.insert(c.name, c)
         }
      }
   end

   #
   # @p
   method create_cookie_header()
      local header, dom, pth, host, path, cookie, now
      #
      # Build the value of the Cookie header for the current request from
      # the stored cookies.  A cookie is sent when the request's host ends
      # with the cookie's domain, the request's file starts with the
      # cookie's path, and the cookie has not expired.  Fails when no
      # cookie applies.
      #
      host := map(request.url.get_address())
      path := request.url.get_file()
      now := Time()
      now.set_current_time()

      header := ""
      every dom := !sort(cookies) do {
         # Domain test: the host must end with the stored domain.
         (host[0 -: *dom[1]] == dom[1]) | next
         every pth := !reverse(sort(dom[2])) do {
            # Path test: the stored path must be a prefix of the request's file.
            match(pth[1], path) | next
            every cookie := (!pth[2].sort())[2] do {
               # Skip cookies whose expiry time has passed.
               (/cookie.expires | cookie.expires.after(now)) | next
               if *header > 0 then
                  header ||:= "; "
               header ||:= cookie.name
               # A cookie without a value is sent as a bare name.
               header ||:= "=" || \cookie.value
            }
         }
      }

      if *header > 0 then
         return header
   end

   #
   # Useful debug function
   # @p
   method dump_cookies()
      local cookie, dom, pth
      #
      # Write the stored cookies to standard output as an indented tree:
      # domain, then path, then one line per cookie.
      #
      every dom := !sort(cookies) do {
         write("domain:", dom[1])
         every pth := !reverse(sort(dom[2])) do {
            write("\tpath:", pth[1])
            every cookie := (!pth[2].sort())[2] do
               write("\t\t", cookie.to_string())
         }
      }
   end

   #
   # @p
   method on_http_error(s)
      # Remember the message for get_http_error(), then report it via error().
      self.http_error := s
      return error(s)
   end

   #
   # After invoking retrieve(), this method can be used to determine
   # whether a failure was caused by a network failure or an HTTP failure
   # (for example 404 not found).  In the former case, &null is returned; in
   # the latter case the error string is returned.
   #
   method get_http_error()
      # &null unless on_http_error() was called during the last fetch.
      return self.http_error
   end

   #
   # @p
   method maybe_open()
      local target
      #
      # Make sure there is a connection to the server of the current
      # request.  An existing connection is reused only when it was opened
      # for the same address and port; otherwise it is closed and a new
      # one is opened.  Fails if the connection cannot be opened.
      #
      target := request.url
      if /connection | /open_url |
         (open_url.get_address() ~== target.get_address()) |
         (open_url.get_port() ~= target.get_port()) then {
         close()
         set_server(target.get_address())
         set_port(target.get_port())
         open() | fail
         # Remember what the connection was opened for.
         open_url := target.clone()
      }
      return
   end

   #
   # Retrieve the current request; doesn't handle redirects and so
   # on - these are handled by the caller, retrieve().
   # @p
   method retrieve_page()
      local attempts_left
      #
      # Try the current request up to 1 + retries times.  Fires
      # "Complete" and returns the result on success; fires "Retrying"
      # between attempts; fires "Failed" and fails once the attempts are
      # used up.
      #
      self.read := self.length := &null
      attempts_left := retries

      repeat {
         if tryone() then {
            store_cookies()
            fire("Complete")
            return result
         }

         # The attempt failed; drop the connection before going on.
         close()

         attempts_left -:= 1
         if attempts_left < 0 then {
            remove(\request.data_file)
            fire("Failed")
            # With no retries configured the failure reason is left as it is.
            if retries = 0 then
               fail
            else
               return error("Gave up after " || (1+retries) || " attempts - last reason:" || get_reason())
         }

         fire("Retrying")
      }
   end

   #
   #
   # Make one attempt at retrieving the current request.  Upon success the
   # {HttpResponse} object is stored in the result field.
   # Failure occurs if the url could not be retrieved.
   #
   method tryone()
      # Step 1: make sure a connection to the right server is open.
      maybe_open() | fail

      # Step 2: send the request line, headers and any body.
      write_request() | fail

      # Step 3: start a fresh response; its status is the first line read back.
      result := HttpResponse()
      result.set_url(request.url)
      result.set_status(read_line()) | fail

      # Step 4: read the response headers.
      read_headers() | fail

      # Step 5: read the body, except for HEAD where no body is read.
      if request.the_method === "HEAD" then
         maybe_close()
      else {
         read_data() | fail
      }

      return
   end

   #
   # @p
   method maybe_close()
      local directive
      #
      # Close the connection when the response carries a
      # "connection: close" header (compared case-insensitively).
      #
      directive := result.get_first_header("connection") | fail
      if map(directive) == "close" then
         close()
   end

   #
   # @p
   method write_request()
      #
      # Send the current request: the request line, the headers and, when
      # post data is present, the body.  Fails as soon as a write fails.
      #

      # Headers that are refreshed on every attempt.
      request.set_header("user-agent", user_agent)
      request.set_header("host", request.url.get_address())
      request.set_header("referer", \request.referer)
      request.unset_header("cookie")
      request.set_header("cookie", create_cookie_header())

      if \keep_alive_flag then
         request.set_header("connection", "Keep-Alive")

      # When part of the body was already read, ask only for the remainder.
      if \self.read > 0 then
         request.set_header("range", "bytes=" || self.read || "-")
      else
         request.unset_header("range")

      write_line(request.the_method || " " || request.url.get_file() || 
                 " HTTP/" || http_version) | fail

      # Body-describing headers are only needed when there is post data.
      if \request.post_data then {
         request.set_header("content-type", request.content_type)
         request.set_header("content-length", *request.post_data)
      }
      request.set_header("authorization", create_authorization_header())
      write_request_headers() | fail
      # An empty line ends the headers.
      write_line() | fail

      if \request.post_data then {
         write_str(request.post_data) | fail
      }

      return
   end

   #
   # @p
   method read_headers()
      local s, key, val
      #
      # Read the response headers up to the empty line that ends them and
      # add each one to the result.  A header is only added once the next
      # line shows that it has no further continuation lines.  Fails if a
      # line cannot be read.
      #
      repeat {
         s := read_line() | fail
         if *s = 0 then {
            # Add last header (if any).
            result.add_header(\key, val)
            break
         }
         s ? {
            #
            # A continuation line starts with a space or a tab.
            #
            if any(' \t') then {
               # A continuation with no header before it has nothing to
               # extend, so it is ignored.
               if \key then
                  val ||:= tab(0)
            } else {
               # Add current header and start a new one.
               result.add_header(\key, val)
               key := tab(upto(':') | 0)
               #
               # The colon may be followed by any amount of blank space,
               # including none; none of it belongs to the value.
               #
               =":"
               tab(many(' \t'))
               val := tab(0)
            }
         }
      }
      return
   end

   #
   # @p
   method read_data()
      #
      # Read the body of the response into either the request's data file
      # or a string buffer, and attach it to the result.  Fails if the
      # body cannot be read completely.
      #
      if find("206", result.get_status()) then {
         #
         # Partial content: carry on where the previous attempt stopped.
         # self.read and self.length keep their values, and a data file
         # is appended to.
         #
         if \request.data_file then
            data_out := ::open(request.data_file, "a") | stop("Couldn't open ", request.data_file)
      } else {
         #
         # A full response: start a new data file or string buffer and
         # reset self.read and self.length.
         #
         if \request.data_file then {
            data_out := ::open(request.data_file, "w") | stop("Couldn't open ", request.data_file)
         } else {
            string_buff := StringBuff()
         }

         self.length := integer(result.get_first_header("content-length")) | &null
         self.read := 0
      }

      if read_data_impl() then {
         #
         # Complete: hand the data over to the result.
         #
         if \request.data_file then {
            ::close(\data_out)
            result.set_data_file(request.data_file)
         } else {
            result.set_data(string_buff.get_string())
         }
         return
      }

      # Incomplete: release the file before reporting failure.
      if \request.data_file then
         ::close(\data_out)
      fail
   end

   #
   # @p
   method read_data_impl()
      #
      # Choose how to read the body: chunked transfer-encoding takes
      # precedence, then a known content length, otherwise read to eof.
      #
      if map(result.get_first_header("transfer-encoding")) == "chunked" then
         return read_chunked()
      if /length then
         return read_to_eof()
      return read_length()
   end

   #
   # @p
   method read_chunked()
      local line, remaining, piece
      #
      # Read a body sent with chunked transfer-encoding.  Every chunk is
      # a line starting with its length in hexadecimal, then that many
      # bytes, then an empty line; a length of zero ends the body.
      #
      repeat {
         line := read_line() | fail
         line ? {
            remaining := integer("16r" || tab(many('0987654321abcdefABCDEF'))) |
               return error("Expected chunk-length specifier; got " || line)
         }
         if remaining = 0 then
            break

         # Read the chunk in pieces of at most 1024 bytes.
         while remaining > 0 do {
            piece := read_str(min(1024, remaining)) | fail
            add_some(piece)
            remaining -:= *piece
         }

         line := read_line() | fail
         if *line > 0 then
            return error("Expected empty line at end of chunk")
      }

      #
      # Skip any trailing lines, up to and including an empty line.
      #
      repeat {
         line := read_line() | fail
         if *line = 0 then
            break
      }

      maybe_close()

      return
   end

   #
   # @p
   method read_to_eof()
      local s
      #
      # No content length.  Read until read_str() fails, which is taken
      # to be eof, then close the connection.  Fails if &errno shows that
      # the read stopped because of an error.
      #
      while s := read_str(1024) do
         add_some(s)

      #
      # Check for errors that may have caused read_str to fail.  The
      # check is made before close(), which might itself change &errno.
      #
      if &errno ~= 0 then {
         close()
         fail
      }

      close()

      return
   end

   #
   # @p
   method read_length()
      local s
      #
      # Got a content length.  Read until self.read reaches length, in
      # pieces of at most 1024 bytes, and leave the connection open -
      # unless told to close it.  Fails if a read fails.
      #
      while self.read < length do {
         s := read_str(min(1024, length - self.read)) | fail
         # add_some() advances self.read by the size of the piece.
         add_some(s)
      }

      maybe_close()

      return
   end

   #
   # Return the length of the current download, or &null if unknown
   #
   method get_length()
      # Set from the content-length header by read_data(); &null when absent.
      return length
   end

   #
   # Return the number of bytes read so far.
   #
   method get_read()
      # The count is advanced by add_some() as each piece of the body arrives.
      return self.read
   end

   #
   # @p
   method add_some(s)
      #
      # Account for a newly read piece of the body, fire "Progress", and
      # append the piece to the data file or to the string buffer.
      #
      self.read +:= *s
      fire("Progress")
      if \request.data_file then
         writes(data_out, s)
      else
         string_buff.add(s)
   end

   method set_one(attr, val)
      #
      # Apply one named attribute.  The attributes known to this class
      # are handled here; any other name is passed on to NetClient.
      #
      case attr of {
         "http_version":
            set_http_version(string_val(attr, val))
         "user_agent":
            set_user_agent(string_val(attr, val))
         "retries":
            set_retries(int_val(attr, val))
         "keep_alive": {
            if test_flag(attr, val) then
               set_keep_alive()
            else
               clear_keep_alive()
         }
         default:
            self.NetClient.set_one(attr, val)
      }
   end

   initially(a[])
      initial {
         # Built once for all instances: the characters that may appear in
         # a header token, i.e. this slice of &ascii less the separators.
         token_char := &ascii[33:128] -- '()<>@,;:\\\"/[]?={} \t'
      }
      self.NetClient.initially()

      # Defaults; set_fields() below may override them.
      http_version := "1.1"
      user_agent := "httpclient.icn [en] (" || &host || ")"
      keep_alive_flag := 1
      retries := 2
      timeout := 12000
      cookies := table()

      set_fields(a)
end