Disconnect streaming sessions when token is revoked

Use Websockex to replace websocket_client

Test that server will disconnect websocket upon token revocation

Lint

Execute session disconnect in background

Refactor streamer test

allow multi-streams

rebase websocket change
This commit is contained in:
Tusooa Zhu 2022-08-19 13:19:38 -04:00 committed by FloatingGhost
parent 772c209914
commit 95e4018c1a
9 changed files with 192 additions and 31 deletions

View file

@ -63,7 +63,8 @@ def start(_type, _args) do
Pleroma.Repo, Pleroma.Repo,
Config.TransferTask, Config.TransferTask,
Pleroma.Emoji, Pleroma.Emoji,
Pleroma.Web.Plugs.RateLimiter.Supervisor Pleroma.Web.Plugs.RateLimiter.Supervisor,
{Task.Supervisor, name: Pleroma.TaskSupervisor}
] ++ ] ++
cachex_children() ++ cachex_children() ++
http_children() ++ http_children() ++

View file

@ -59,7 +59,7 @@ def websocket_init(state) do
"#{__MODULE__} accepted websocket connection for user #{(state.user || %{id: "anonymous"}).id}, topic #{state.topic}" "#{__MODULE__} accepted websocket connection for user #{(state.user || %{id: "anonymous"}).id}, topic #{state.topic}"
) )
Streamer.add_socket(state.topic, state.user) Streamer.add_socket(state.topic, state.oauth_token)
{:ok, %{state | timer: timer()}} {:ok, %{state | timer: timer()}}
end end
@ -139,6 +139,10 @@ def websocket_info(:tick, state) do
{:reply, :ping, %{state | timer: nil, count: 0}, :hibernate} {:reply, :ping, %{state | timer: nil, count: 0}, :hibernate}
end end
def websocket_info(:close, state) do
{:stop, state}
end
# State can be `[]` only in case we terminate before switching to websocket, # State can be `[]` only in case we terminate before switching to websocket,
# we already log errors for these cases in `init/1`, so just do nothing here # we already log errors for these cases in `init/1`, so just do nothing here
def terminate(_reason, _req, []), do: :ok def terminate(_reason, _req, []), do: :ok

View file

@ -21,6 +21,18 @@ def revoke(%App{} = app, %{"token" => token} = _attrs) do
@doc "Revokes access token" @doc "Revokes access token"
@spec revoke(Token.t()) :: {:ok, Token.t()} | {:error, Ecto.Changeset.t()} @spec revoke(Token.t()) :: {:ok, Token.t()} | {:error, Ecto.Changeset.t()}
def revoke(%Token{} = token) do def revoke(%Token{} = token) do
Repo.delete(token) with {:ok, token} <- Repo.delete(token) do
Task.Supervisor.start_child(
Pleroma.TaskSupervisor,
Pleroma.Web.Streamer,
:close_streams_by_oauth_token,
[token],
restart: :transient
)
{:ok, token}
else
result -> result
end
end end
end end

View file

@ -36,7 +36,7 @@ def registry, do: @registry
{:ok, topic :: String.t()} | {:error, :bad_topic} | {:error, :unauthorized} {:ok, topic :: String.t()} | {:error, :bad_topic} | {:error, :unauthorized}
def get_topic_and_add_socket(stream, user, oauth_token, params \\ %{}) do def get_topic_and_add_socket(stream, user, oauth_token, params \\ %{}) do
with {:ok, topic} <- get_topic(stream, user, oauth_token, params) do with {:ok, topic} <- get_topic(stream, user, oauth_token, params) do
add_socket(topic, user) add_socket(topic, oauth_token)
end end
end end
@ -124,10 +124,10 @@ def get_topic(_stream, _user, _oauth_token, _params) do
end end
@doc "Registers the process for streaming. Use `get_topic/3` to get the full authorized topic." @doc "Registers the process for streaming. Use `get_topic/3` to get the full authorized topic."
def add_socket(topic, user) do def add_socket(topic, oauth_token) do
if should_env_send?() do if should_env_send?() do
auth? = if user, do: true oauth_token_id = if oauth_token, do: oauth_token.id, else: false
Registry.register(@registry, topic, auth?) Registry.register(@registry, topic, oauth_token_id)
end end
{:ok, topic} {:ok, topic}
@ -311,6 +311,22 @@ defp thread_containment(activity, user) do
end end
end end
def close_streams_by_oauth_token(oauth_token) do
if should_env_send?() do
Registry.select(
@registry,
[
{
{:"$1", :"$2", :"$3"},
[{:==, :"$3", oauth_token.id}],
[:"$2"]
}
]
)
|> Enum.each(fn pid -> send(pid, :close) end)
end
end
# In test environement, only return true if the registry is started. # In test environement, only return true if the registry is started.
# In benchmark environment, returns false. # In benchmark environment, returns false.
# In any other environment, always returns true. # In any other environment, always returns true.

View file

@ -206,7 +206,7 @@ defp deps do
# temporary downgrade for excoveralls, hackney until hackney max_connections bug will be fixed # temporary downgrade for excoveralls, hackney until hackney max_connections bug will be fixed
{:excoveralls, "0.12.3", only: :test}, {:excoveralls, "0.12.3", only: :test},
{:mox, "~> 1.0", only: :test}, {:mox, "~> 1.0", only: :test},
{:websocket_client, git: "https://github.com/jeremyong/websocket_client.git", only: :test} {:websockex, "~> 0.4.3", only: :test}
] ++ oauth_deps() ] ++ oauth_deps()
end end

View file

@ -120,5 +120,5 @@
"unsafe": {:hex, :unsafe, "1.0.1", "a27e1874f72ee49312e0a9ec2e0b27924214a05e3ddac90e91727bc76f8613d8", [:mix], [], "hexpm", "6c7729a2d214806450d29766abc2afaa7a2cbecf415be64f36a6691afebb50e5"}, "unsafe": {:hex, :unsafe, "1.0.1", "a27e1874f72ee49312e0a9ec2e0b27924214a05e3ddac90e91727bc76f8613d8", [:mix], [], "hexpm", "6c7729a2d214806450d29766abc2afaa7a2cbecf415be64f36a6691afebb50e5"},
"vex": {:hex, :vex, "0.9.0", "613ea5eb3055662e7178b83e25b2df0975f68c3d8bb67c1645f0573e1a78d606", [:mix], [], "hexpm", "c69fff44d5c8aa3f1faee71bba1dcab05dd36364c5a629df8bb11751240c857f"}, "vex": {:hex, :vex, "0.9.0", "613ea5eb3055662e7178b83e25b2df0975f68c3d8bb67c1645f0573e1a78d606", [:mix], [], "hexpm", "c69fff44d5c8aa3f1faee71bba1dcab05dd36364c5a629df8bb11751240c857f"},
"web_push_encryption": {:hex, :web_push_encryption, "0.3.1", "76d0e7375142dfee67391e7690e89f92578889cbcf2879377900b5620ee4708d", [:mix], [{:httpoison, "~> 1.0", [hex: :httpoison, repo: "hexpm", optional: false]}, {:jose, "~> 1.11.1", [hex: :jose, repo: "hexpm", optional: false]}], "hexpm", "4f82b2e57622fb9337559058e8797cb0df7e7c9790793bdc4e40bc895f70e2a2"}, "web_push_encryption": {:hex, :web_push_encryption, "0.3.1", "76d0e7375142dfee67391e7690e89f92578889cbcf2879377900b5620ee4708d", [:mix], [{:httpoison, "~> 1.0", [hex: :httpoison, repo: "hexpm", optional: false]}, {:jose, "~> 1.11.1", [hex: :jose, repo: "hexpm", optional: false]}], "hexpm", "4f82b2e57622fb9337559058e8797cb0df7e7c9790793bdc4e40bc895f70e2a2"},
"websocket_client": {:git, "https://github.com/jeremyong/websocket_client.git", "9a6f65d05ebf2725d62fb19262b21f1805a59fbf", []}, "websockex": {:hex, :websockex, "0.4.3", "92b7905769c79c6480c02daacaca2ddd49de936d912976a4d3c923723b647bf0", [:mix], [], "hexpm", "95f2e7072b85a3a4cc385602d42115b73ce0b74a9121d0d6dbbf557645ac53e4"},
} }

View file

@ -34,15 +34,20 @@ def start_socket(qs \\ nil, headers \\ []) do
test "allows multi-streams" do test "allows multi-streams" do
capture_log(fn -> capture_log(fn ->
assert {:ok, _} = start_socket() assert {:ok, _} = start_socket()
assert {:error, {404, _}} = start_socket("?stream=ncjdk")
assert {:error, %WebSockex.RequestError{code: 404, message: "Not Found"}} =
start_socket("?stream=ncjdk")
Process.sleep(30) Process.sleep(30)
end) end)
end end
test "requires authentication and a valid token for protected streams" do test "requires authentication and a valid token for protected streams" do
capture_log(fn -> capture_log(fn ->
assert {:error, {401, _}} = start_socket("?stream=user&access_token=aaaaaaaaaaaa") assert {:error, %WebSockex.RequestError{code: 401}} =
assert {:error, {401, _}} = start_socket("?stream=user") start_socket("?stream=user&access_token=aaaaaaaaaaaa")
assert {:error, %WebSockex.RequestError{code: 401}} = start_socket("?stream=user")
Process.sleep(30) Process.sleep(30)
end) end)
end end
@ -91,7 +96,7 @@ test "receives well formatted events" do
{:ok, token} = OAuth.Token.exchange_token(app, auth) {:ok, token} = OAuth.Token.exchange_token(app, auth)
%{user: user, token: token} %{app: app, user: user, token: token}
end end
test "accepts valid tokens", state do test "accepts valid tokens", state do
@ -102,7 +107,7 @@ test "accepts the 'user' stream", %{token: token} = _state do
assert {:ok, _} = start_socket("?stream=user&access_token=#{token.token}") assert {:ok, _} = start_socket("?stream=user&access_token=#{token.token}")
capture_log(fn -> capture_log(fn ->
assert {:error, {401, _}} = start_socket("?stream=user") assert {:error, %WebSockex.RequestError{code: 401}} = start_socket("?stream=user")
Process.sleep(30) Process.sleep(30)
end) end)
end end
@ -111,7 +116,9 @@ test "accepts the 'user:notification' stream", %{token: token} = _state do
assert {:ok, _} = start_socket("?stream=user:notification&access_token=#{token.token}") assert {:ok, _} = start_socket("?stream=user:notification&access_token=#{token.token}")
capture_log(fn -> capture_log(fn ->
assert {:error, {401, _}} = start_socket("?stream=user:notification") assert {:error, %WebSockex.RequestError{code: 401}} =
start_socket("?stream=user:notification")
Process.sleep(30) Process.sleep(30)
end) end)
end end
@ -120,11 +127,27 @@ test "accepts valid token on Sec-WebSocket-Protocol header", %{token: token} do
assert {:ok, _} = start_socket("?stream=user", [{"Sec-WebSocket-Protocol", token.token}]) assert {:ok, _} = start_socket("?stream=user", [{"Sec-WebSocket-Protocol", token.token}])
capture_log(fn -> capture_log(fn ->
assert {:error, {401, _}} = assert {:error, %WebSockex.RequestError{code: 401}} =
start_socket("?stream=user", [{"Sec-WebSocket-Protocol", "I am a friend"}]) start_socket("?stream=user", [{"Sec-WebSocket-Protocol", "I am a friend"}])
Process.sleep(30) Process.sleep(30)
end) end)
end end
test "disconnect when token is revoked", %{app: app, user: user, token: token} do
assert {:ok, _} = start_socket("?stream=user:notification&access_token=#{token.token}")
assert {:ok, _} = start_socket("?stream=user&access_token=#{token.token}")
{:ok, auth} = OAuth.Authorization.create_authorization(app, user)
{:ok, token2} = OAuth.Token.exchange_token(app, auth)
assert {:ok, _} = start_socket("?stream=user&access_token=#{token2.token}")
OAuth.Token.Strategy.Revoke.revoke(token)
assert_receive {:close, _}
assert_receive {:close, _}
refute_receive {:close, _}
end
end end
end end

View file

@ -760,4 +760,105 @@ test "it sends conversation update to the 'direct' stream when a message is dele
assert last_status["id"] == to_string(create_activity.id) assert last_status["id"] == to_string(create_activity.id)
end end
end end
describe "stop streaming if token got revoked" do
setup do
child_proc = fn start, finalize ->
fn ->
start.()
receive do
{StreamerTest, :ready} ->
assert_receive {:render_with_user, _, "update.json", _, _}
receive do
{StreamerTest, :revoked} -> finalize.()
end
end
end
end
starter = fn user, token ->
fn -> Streamer.get_topic_and_add_socket("user", user, token) end
end
hit = fn -> assert_receive :close end
miss = fn -> refute_receive :close end
send_all = fn tasks, thing -> Enum.each(tasks, &send(&1.pid, thing)) end
%{
child_proc: child_proc,
starter: starter,
hit: hit,
miss: miss,
send_all: send_all
}
end
test "do not revoke other tokens", %{
child_proc: child_proc,
starter: starter,
hit: hit,
miss: miss,
send_all: send_all
} do
%{user: user, token: token} = oauth_access(["read"])
%{token: token2} = oauth_access(["read"], user: user)
%{user: user2, token: user2_token} = oauth_access(["read"])
post_user = insert(:user)
CommonAPI.follow(user, post_user)
CommonAPI.follow(user2, post_user)
tasks = [
Task.async(child_proc.(starter.(user, token), hit)),
Task.async(child_proc.(starter.(user, token2), miss)),
Task.async(child_proc.(starter.(user2, user2_token), miss))
]
{:ok, _} =
CommonAPI.post(post_user, %{
status: "hi"
})
send_all.(tasks, {StreamerTest, :ready})
Pleroma.Web.OAuth.Token.Strategy.Revoke.revoke(token)
send_all.(tasks, {StreamerTest, :revoked})
Enum.each(tasks, &Task.await/1)
end
test "revoke all streams for this token", %{
child_proc: child_proc,
starter: starter,
hit: hit,
send_all: send_all
} do
%{user: user, token: token} = oauth_access(["read"])
post_user = insert(:user)
CommonAPI.follow(user, post_user)
tasks = [
Task.async(child_proc.(starter.(user, token), hit)),
Task.async(child_proc.(starter.(user, token), hit))
]
{:ok, _} =
CommonAPI.post(post_user, %{
status: "hi"
})
send_all.(tasks, {StreamerTest, :ready})
Pleroma.Web.OAuth.Token.Strategy.Revoke.revoke(token)
send_all.(tasks, {StreamerTest, :revoked})
Enum.each(tasks, &Task.await/1)
end
end
end end

View file

@ -5,18 +5,17 @@
defmodule Pleroma.Integration.WebsocketClient do defmodule Pleroma.Integration.WebsocketClient do
# https://github.com/phoenixframework/phoenix/blob/master/test/support/websocket_client.exs # https://github.com/phoenixframework/phoenix/blob/master/test/support/websocket_client.exs
use WebSockex
@doc """ @doc """
Starts the WebSocket server for given ws URL. Received Socket.Message's Starts the WebSocket server for given ws URL. Received Socket.Message's
are forwarded to the sender pid are forwarded to the sender pid
""" """
def start_link(sender, url, headers \\ []) do def start_link(sender, url, headers \\ []) do
:crypto.start() WebSockex.start_link(
:ssl.start() url,
:websocket_client.start_link(
String.to_charlist(url),
__MODULE__, __MODULE__,
[sender], %{sender: sender},
extra_headers: headers extra_headers: headers
) )
end end
@ -36,27 +35,32 @@ def send_text(server_pid, msg) do
end end
@doc false @doc false
def init([sender], _conn_state) do @impl true
{:ok, %{sender: sender}} def handle_frame(frame, state) do
end
@doc false
def websocket_handle(frame, _conn_state, state) do
send(state.sender, frame) send(state.sender, frame)
{:ok, state} {:ok, state}
end end
@impl true
def handle_disconnect(conn_status, state) do
send(state.sender, {:close, conn_status})
{:ok, state}
end
@doc false @doc false
def websocket_info({:text, msg}, _conn_state, state) do @impl true
def handle_info({:text, msg}, state) do
{:reply, {:text, msg}, state} {:reply, {:text, msg}, state}
end end
def websocket_info(:close, _conn_state, _state) do @impl true
def handle_info(:close, _state) do
{:close, <<>>, "done"} {:close, <<>>, "done"}
end end
@doc false @doc false
def websocket_terminate(_reason, _conn_state, _state) do @impl true
def terminate(_reason, _state) do
:ok :ok
end end
end end