From 7963dcf3c5223c90eaefdb10b0c8293473fb77cf Mon Sep 17 00:00:00 2001 From: FloatingGhost Date: Tue, 23 Aug 2022 12:08:26 +0100 Subject: [PATCH] assign user session --- lib/pleroma/helpers/auth_helper.ex | 13 +++++++++++++ lib/pleroma/web/o_auth/o_auth_controller.ex | 21 +++++++++++++++++++-- lib/pleroma/web/o_auth/token.ex | 9 +++++++++ lib/pleroma/web/o_auth/token/query.ex | 9 +++++++++ 4 files changed, 50 insertions(+), 2 deletions(-) diff --git a/lib/pleroma/helpers/auth_helper.ex b/lib/pleroma/helpers/auth_helper.ex index 13e4c8158..37765da4d 100644 --- a/lib/pleroma/helpers/auth_helper.ex +++ b/lib/pleroma/helpers/auth_helper.ex @@ -9,6 +9,7 @@ defmodule Pleroma.Helpers.AuthHelper do import Plug.Conn @oauth_token_session_key :oauth_token + @oauth_user_session_key :oauth_user @doc """ Skips OAuth permissions (scopes) checks, assigns nil `:token`. @@ -43,4 +44,16 @@ def put_session_token(%Conn{} = conn, token) when is_binary(token) do def delete_session_token(%Conn{} = conn) do delete_session(conn, @oauth_token_session_key) end + + def put_session_user(%Conn{} = conn, user) do + put_session(conn, @oauth_user_session_key, user) + end + + def delete_session_user(%Conn{} = conn) do + delete_session(conn, @oauth_user_session_key) + end + + def get_session_user(%Conn{} = conn) do + get_session(conn, @oauth_user_session_key) + end end diff --git a/lib/pleroma/web/o_auth/o_auth_controller.ex b/lib/pleroma/web/o_auth/o_auth_controller.ex index 358120fe6..45e99d35b 100644 --- a/lib/pleroma/web/o_auth/o_auth_controller.ex +++ b/lib/pleroma/web/o_auth/o_auth_controller.ex @@ -70,7 +70,21 @@ def authorize( end end - def authorize(%Plug.Conn{} = conn, params), do: do_authorize(conn, params) + def authorize(%Plug.Conn{} = conn, params) do + # if we have a user in the session, attempt to authenticate as them + # otherwise show the login form + with user_id <- AuthHelper.get_session_user(conn), + false <- is_nil(user_id), + %User{} = user <- User.get_cached_by_id(user_id), + %App{} = app <- Repo.get_by(App, client_id: params["client_id"]), + {:ok, %Token{} = token} <- Token.get_preeexisting_by_app_and_user(app, user) do + conn + |> assign(:token, token) + |> handle_existing_authorization(params) + else + _ -> do_authorize(conn, params) + end + end defp do_authorize(%Plug.Conn{} = conn, params) do app = Repo.get_by(App, client_id: params["client_id"]) @@ -148,7 +162,9 @@ def create_authorization(%Plug.Conn{assigns: %{user: %User{} = user}} = conn, pa def create_authorization(%Plug.Conn{} = conn, %{"authorization" => _} = params, opts) do with {:ok, auth, user} <- do_create_authorization(conn, params, opts[:user]), {:mfa_required, _, _, false} <- {:mfa_required, user, auth, MFA.require?(user)} do - after_create_authorization(conn, auth, params) + conn + |> AuthHelper.put_session_user(user.id) + |> after_create_authorization(auth, params) else error -> handle_create_authorization_error(conn, error, params) @@ -321,6 +337,7 @@ def token_exchange(%Plug.Conn{} = conn, params), do: bad_request(conn, params) def after_token_exchange(%Plug.Conn{} = conn, %{token: token} = view_params) do conn |> AuthHelper.put_session_token(token.token) + |> AuthHelper.put_session_user(token.user_id) |> json(OAuthView.render("token.json", view_params)) end diff --git a/lib/pleroma/web/o_auth/token.ex b/lib/pleroma/web/o_auth/token.ex index 9d69e9db4..686e6715b 100644 --- a/lib/pleroma/web/o_auth/token.ex +++ b/lib/pleroma/web/o_auth/token.ex @@ -70,6 +70,15 @@ def exchange_token(app, auth) do end end + def get_preeexisting_by_app_and_user(app, user) do + Query.get_by_app(app.id) + |> Query.get_by_user(user.id) + |> Query.get_unexpired() + |> Query.preload([:user]) + |> Query.limit(1) + |> Repo.find_resource() + end + defp put_token(changeset) do changeset |> change(%{token: Token.Utils.generate_token()}) diff --git a/lib/pleroma/web/o_auth/token/query.ex b/lib/pleroma/web/o_auth/token/query.ex index d16a759d8..1415191b7 100644 --- a/lib/pleroma/web/o_auth/token/query.ex +++ b/lib/pleroma/web/o_auth/token/query.ex @@ -38,6 +38,15 @@ def get_by_user(query \\ Token, user_id) do from(q in query, where: q.user_id == ^user_id) end + def get_unexpired(query) do + now = NaiveDateTime.utc_now() + from(q in query, where: q.valid_until > ^now) + end + + def limit(query, limit) do + from(q in query, limit: ^limit) + end + @spec preload(query, any) :: query def preload(query \\ Token, assoc_preload \\ [])