Skip to content

Commit

Permalink
Refresh GitHubReadQueue when we get an authentation failure
Browse files Browse the repository at this point in the history
  • Loading branch information
bobbobbio committed Jan 1, 2025
1 parent 4342563 commit 6e42e63
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 58 deletions.
144 changes: 87 additions & 57 deletions crates/maelstrom-github/src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,15 @@ pub trait QueueConnection {
async fn list(&self) -> Result<Vec<Artifact>>;
}

pub enum ReadResponse {
Data { data: Vec<u8>, etag: Etag },
NoData,
AuthenticationFailed,
}

#[allow(async_fn_in_trait)]
pub trait QueueBlob: Send + Sync + 'static {
async fn read(&self, index: usize, etag: &Option<Etag>) -> Result<Option<(Vec<u8>, Etag)>>;
async fn read(&self, index: usize, etag: &Option<Etag>) -> Result<ReadResponse>;
fn write(&self, data: Vec<u8>) -> impl Future<Output = Result<()>> + Send;
}

Expand All @@ -47,7 +53,7 @@ impl QueueConnection for GitHubClient {
}

impl QueueBlob for BlobClient {
async fn read(&self, index: usize, etag: &Option<Etag>) -> Result<Option<(Vec<u8>, Etag)>> {
async fn read(&self, index: usize, etag: &Option<Etag>) -> Result<ReadResponse> {
let mut builder = self.get().range(index..);

if let Some(etag) = etag {
Expand All @@ -64,7 +70,10 @@ impl QueueBlob for BlobClient {
match resp {
Ok(resp) => {
let msg = resp.data.collect().await?;
Ok(Some((msg.to_vec(), resp.blob.properties.etag)))
Ok(ReadResponse::Data {
data: msg.to_vec(),
etag: resp.blob.properties.etag,
})
}
Err(err) => {
use azure_core::{error::ErrorKind, StatusCode};
Expand All @@ -74,13 +83,19 @@ impl QueueBlob for BlobClient {
status: StatusCode::NotModified,
error_code: Some(error_code),
} if error_code == "ConditionNotMet" => {
return Ok(None);
return Ok(ReadResponse::NoData);
}
ErrorKind::HttpResponse {
status: StatusCode::RequestedRangeNotSatisfiable,
error_code: Some(error_code),
} if error_code == "InvalidRange" => {
return Ok(None);
return Ok(ReadResponse::NoData);
}
ErrorKind::HttpResponse {
status: StatusCode::Forbidden,
error_code: Some(error_code),
} if error_code == "AuthenticationFailed" => {
return Ok(ReadResponse::AuthenticationFailed);
}
_ => {}
}
Expand All @@ -104,37 +119,51 @@ enum MessageHeader {

const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(60);

pub struct GitHubReadQueue<BlobT = BlobClient> {
blob: BlobT,
pub struct GitHubReadQueue<ConnT: QueueConnection = GitHubClient> {
conn: Arc<ConnT>,
blob: ConnT::Blob,
index: usize,
etag: Option<Etag>,
pending: VecDeque<Option<Vec<u8>>>,
read_timeout: Duration,
backend_ids: BackendIds,
key: String,
}

impl<BlobT: QueueBlob> GitHubReadQueue<BlobT> {
async fn new<ConnT>(
conn: &ConnT,
impl<ConnT> GitHubReadQueue<ConnT>
where
ConnT: QueueConnection,
{
async fn new(
conn: Arc<ConnT>,
read_timeout: Duration,
backend_ids: BackendIds,
key: &str,
) -> Result<Self>
where
ConnT: QueueConnection<Blob = BlobT>,
{
let blob = conn.get_blob(backend_ids, key).await?;
) -> Result<Self> {
let blob = conn.get_blob(backend_ids.clone(), key).await?;
Ok(Self {
conn,
blob,
index: 0,
etag: None,
pending: Default::default(),
read_timeout,
backend_ids,
key: key.into(),
})
}

async fn maybe_read_msg(&mut self) -> Result<Option<Vec<u8>>> {
let Some((msg, etag)) = self.blob.read(self.index, &self.etag).await? else {
return Ok(None);
let (msg, etag) = match self.blob.read(self.index, &self.etag).await? {
ReadResponse::Data { data, etag } => (data, etag),
ReadResponse::NoData => return Ok(None),
ReadResponse::AuthenticationFailed => {
self.blob = self
.conn
.get_blob(self.backend_ids.clone(), &self.key)
.await?;
return Ok(None);
}
};

self.etag = Some(etag);
Expand Down Expand Up @@ -243,32 +272,29 @@ async fn wait_for_artifact(conn: &impl QueueConnection, key: &str) -> Result<()>
Ok(())
}

pub struct GitHubQueue<BlobT = BlobClient> {
read: GitHubReadQueue<BlobT>,
write: GitHubWriteQueue<BlobT>,
pub struct GitHubQueue<ConnT: QueueConnection = GitHubClient> {
read: GitHubReadQueue<ConnT>,
write: GitHubWriteQueue<ConnT::Blob>,
}

impl<BlobT: QueueBlob> GitHubQueue<BlobT> {
async fn new<ConnT>(
conn: &ConnT,
impl<ConnT> GitHubQueue<ConnT>
where
ConnT: QueueConnection,
{
async fn new(
conn: Arc<ConnT>,
read_timeout: Duration,
read_backend_ids: BackendIds,
read_key: &str,
write_key: &str,
) -> Result<Self>
where
ConnT: QueueConnection<Blob = BlobT>,
{
) -> Result<Self> {
Ok(Self {
write: GitHubWriteQueue::new(conn, read_timeout / 4, write_key).await?,
write: GitHubWriteQueue::new(&*conn, read_timeout / 4, write_key).await?,
read: GitHubReadQueue::new(conn, read_timeout, read_backend_ids, read_key).await?,
})
}

async fn maybe_connect<ConnT>(conn: &ConnT, id: &str) -> Result<Option<Self>>
where
ConnT: QueueConnection<Blob = BlobT>,
{
async fn maybe_connect(conn: Arc<ConnT>, id: &str) -> Result<Option<Self>> {
let artifacts = conn.list().await?;
if let Some(listener) = artifacts.iter().find(|a| a.name == format!("{id}-listen")) {
let Artifact {
Expand All @@ -278,10 +304,10 @@ impl<BlobT: QueueBlob> GitHubQueue<BlobT> {
let self_id = uuid::Uuid::new_v4().to_string();

let write_key = format!("{self_id}-{key}-up");
let write = GitHubWriteQueue::new(conn, DEFAULT_READ_TIMEOUT / 4, &write_key).await?;
let write = GitHubWriteQueue::new(&*conn, DEFAULT_READ_TIMEOUT / 4, &write_key).await?;

let read_key = format!("{self_id}-{key}-down");
wait_for_artifact(conn, &read_key).await?;
wait_for_artifact(&*conn, &read_key).await?;
let read =
GitHubReadQueue::new(conn, DEFAULT_READ_TIMEOUT, backend_ids.clone(), &read_key)
.await?;
Expand All @@ -292,12 +318,9 @@ impl<BlobT: QueueBlob> GitHubQueue<BlobT> {
}
}

pub async fn connect<ConnT>(conn: &ConnT, id: &str) -> Result<Self>
where
ConnT: QueueConnection<Blob = BlobT>,
{
pub async fn connect(conn: Arc<ConnT>, id: &str) -> Result<Self> {
loop {
if let Some(socket) = Self::maybe_connect(conn, id).await? {
if let Some(socket) = Self::maybe_connect(conn.clone(), id).await? {
return Ok(socket);
}
}
Expand All @@ -315,7 +338,7 @@ impl<BlobT: QueueBlob> GitHubQueue<BlobT> {
self.write.shut_down().await
}

pub fn into_split(self) -> (GitHubReadQueue<BlobT>, GitHubWriteQueue<BlobT>) {
pub fn into_split(self) -> (GitHubReadQueue<ConnT>, GitHubWriteQueue<ConnT::Blob>) {
(self.read, self.write)
}
}
Expand All @@ -340,7 +363,7 @@ where
})
}

async fn maybe_accept_one(&mut self) -> Result<Option<GitHubQueue<ConnT::Blob>>> {
async fn maybe_accept_one(&mut self) -> Result<Option<GitHubQueue<ConnT>>> {
let artifacts = self.conn.list().await?;
if let Some(connected) = artifacts.iter().find(|a| {
a.name.ends_with(&format!("{}-up", self.id)) && !self.accepted.contains(&a.name)
Expand All @@ -350,7 +373,7 @@ where
} = connected;
let key = name.strip_suffix("-up").unwrap();
let socket = GitHubQueue::new(
&*self.conn,
self.conn.clone(),
DEFAULT_READ_TIMEOUT,
backend_ids.clone(),
&format!("{key}-up"),
Expand All @@ -364,7 +387,7 @@ where
}
}

pub async fn accept_one(&mut self) -> Result<GitHubQueue<ConnT::Blob>> {
pub async fn accept_one(&mut self) -> Result<GitHubQueue<ConnT>> {
loop {
if let Some(socket) = self.maybe_accept_one().await? {
return Ok(socket);
Expand Down Expand Up @@ -394,6 +417,10 @@ mod tests {
fn len(&self) -> usize {
self.data.lock().unwrap().len()
}

fn data(&self) -> Vec<u8> {
self.data.lock().unwrap().clone()
}
}

fn b_ids() -> BackendIds {
Expand Down Expand Up @@ -451,7 +478,7 @@ mod tests {
}

impl QueueBlob for FakeBlob {
async fn read(&self, index: usize, etag: &Option<Etag>) -> Result<Option<(Vec<u8>, Etag)>> {
async fn read(&self, index: usize, etag: &Option<Etag>) -> Result<ReadResponse> {
use sha2::Digest as _;

tokio::task::yield_now().await;
Expand All @@ -466,14 +493,17 @@ mod tests {

if let Some(not_etag) = etag {
if not_etag == &actual_etag {
return Ok(None);
return Ok(ReadResponse::NoData);
}
}

if !data.is_empty() {
assert!(index < data.len());
}
Ok(Some((data[index..].to_vec(), actual_etag)))
Ok(ReadResponse::Data {
data: data[index..].to_vec(),
etag: actual_etag,
})
}

async fn write(&self, data: Vec<u8>) -> Result<()> {
Expand All @@ -491,7 +521,7 @@ mod tests {
async fn read_single_msg() {
let conn = FakeConnection::default();
let b = conn.create_blob("foo").await.unwrap();
let mut queue = GitHubReadQueue::new(&conn, SHORT_DURATION, b_ids(), "foo")
let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
.await
.unwrap();

Expand All @@ -509,7 +539,7 @@ mod tests {
async fn read_multiple_msgs() {
let conn = FakeConnection::default();
let b = conn.create_blob("foo").await.unwrap();
let mut queue = GitHubReadQueue::new(&conn, SHORT_DURATION, b_ids(), "foo")
let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
.await
.unwrap();

Expand All @@ -533,7 +563,7 @@ mod tests {
async fn read_multiple_msgs_interleaved() {
let conn = FakeConnection::default();
let b = conn.create_blob("foo").await.unwrap();
let mut queue = GitHubReadQueue::new(&conn, SHORT_DURATION, b_ids(), "foo")
let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
.await
.unwrap();

Expand All @@ -553,7 +583,7 @@ mod tests {
async fn read_ignores_keep_alive_msgs() {
let conn = FakeConnection::default();
let b = conn.create_blob("foo").await.unwrap();
let mut queue = GitHubReadQueue::new(&conn, SHORT_DURATION, b_ids(), "foo")
let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
.await
.unwrap();

Expand All @@ -578,7 +608,7 @@ mod tests {
async fn read_with_shutdown() {
let conn = FakeConnection::default();
let b = conn.create_blob("foo").await.unwrap();
let mut queue = GitHubReadQueue::new(&conn, SHORT_DURATION, b_ids(), "foo")
let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
.await
.unwrap();

Expand All @@ -600,7 +630,7 @@ mod tests {
async fn read_timeout() {
let conn = FakeConnection::default();
let _ = conn.create_blob("foo").await.unwrap();
let mut queue = GitHubReadQueue::new(&conn, SHORT_DURATION, b_ids(), "foo")
let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
.await
.unwrap();

Expand All @@ -618,7 +648,7 @@ mod tests {
expected.extend(sent);

let b = conn.get_blob(b_ids(), "foo").await.unwrap();
assert_eq!(b.read(0, &None).await.unwrap().unwrap().0, expected);
assert_eq!(b.data(), expected);
}

#[tokio::test]
Expand All @@ -630,7 +660,7 @@ mod tests {
let expected = bincode::serialize(&MessageHeader::Shutdown).unwrap();

let b = conn.get_blob(b_ids(), "foo").await.unwrap();
assert_eq!(b.read(0, &None).await.unwrap().unwrap().0, expected);
assert_eq!(b.data(), expected);
}

#[tokio::test]
Expand All @@ -643,7 +673,7 @@ mod tests {
drop(queue);

let b = conn.get_blob(b_ids(), "foo").await.unwrap();
let data = b.read(0, &None).await.unwrap().unwrap().0;
let data = b.data();
let mut cursor = &data[..];

let mut keep_alive_count = 0;
Expand All @@ -667,7 +697,7 @@ mod tests {
queue_b.write_msg(&b"hello"[..]).await.unwrap();
});

let mut queue_a = GitHubQueue::connect(&*conn, "foo").await.unwrap();
let mut queue_a = GitHubQueue::connect(conn, "foo").await.unwrap();
let msg = queue_a.read_msg().await.unwrap().unwrap();
assert_eq!(msg, b"hello");
}
Expand Down Expand Up @@ -696,7 +726,7 @@ mod tests {
}

async fn connector(client: GitHubClient) {
let mut sock = GitHubQueue::connect(&client, "foo").await.unwrap();
let mut sock = GitHubQueue::connect(Arc::new(client), "foo").await.unwrap();
while let Some(msg) = sock.read_msg().await.unwrap() {
assert_eq!(msg, b"ping");
sock.write_msg(&b"pong"[..]).await.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion crates/maelstrom-worker/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl BrokerConnection for GitHubQueue {
log: &Logger,
) -> Result<(Self::Read, Self::Write)> {
let client = crate::github_client_factory()?;
let (read, mut write) = GitHubQueue::connect(&*client, "maelstrom-broker")
let (read, mut write) = GitHubQueue::connect(client, "maelstrom-broker")
.await
.map_err(|err| {
error!(log, "error connecting to broker"; "error" => %err);
Expand Down

0 comments on commit 6e42e63

Please sign in to comment.