# Copyright 2020-2023 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Author: Aniek Roelofs, Keegan Smith
import os
from unittest import TestCase
from unittest.mock import patch, call
import pendulum
from airflow.utils.state import State
from airflow.models.connection import Connection
import vcr
from oaebu_workflows.config import test_fixtures_folder
from oaebu_workflows.oaebu_partners import partner_from_str
from oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope import (
UclDiscoveryTelescope,
get_isbn_eprint_mappings,
download_discovery_stats,
transform_discovery_stats,
)
from observatory.platform.api import get_dataset_releases
from observatory.platform.observatory_config import Workflow
from observatory.platform.bigquery import bq_table_id
from observatory.platform.gcs import gcs_blob_name_from_path
from observatory.platform.observatory_environment import (
ObservatoryEnvironment,
ObservatoryTestCase,
find_free_port,
load_and_parse_json,
)
class TestUclDiscoveryTelescope(ObservatoryTestCase):
    """Tests for the Ucl Discovery telescope"""

    def __init__(self, *args, **kwargs):
        """Constructor which sets up variables used by tests.

        Reads the GCP project ID and data location from the environment and resolves the
        paths of the fixtures (VCR cassette and expected table content) used by the tests.
        """
        super(TestUclDiscoveryTelescope, self).__init__(*args, **kwargs)
        self.project_id = os.getenv("TEST_GCP_PROJECT_ID")
        self.data_location = os.getenv("TEST_GCP_DATA_LOCATION")

        fixtures_folder = test_fixtures_folder(workflow_module="ucl_discovery_telescope")
        self.download_cassette = os.path.join(fixtures_folder, "download_cassette.yaml")
        self.test_table = os.path.join(fixtures_folder, "test_table.json")

    def test_dag_structure(self):
        """Test that the UCL Discovery DAG has the correct structure."""
        dag = UclDiscoveryTelescope(
            dag_id="Test_Dag", cloud_workspace=self.fake_cloud_workspace, sheet_id="foo"
        ).make_dag()
        # Mapping of task ID -> list of downstream task IDs
        self.assert_dag_structure(
            {
                "check_dependencies": ["download"],
                "download": ["upload_downloaded"],
                "upload_downloaded": ["transform"],
                "transform": ["upload_transformed"],
                "upload_transformed": ["bq_load"],
                "bq_load": ["add_new_dataset_releases"],
                "add_new_dataset_releases": ["cleanup"],
                "cleanup": [],
            },
            dag,
        )

    def test_dag_load(self):
        """Test that the UCL Discovery DAG can be loaded from a DAG bag."""
        env = ObservatoryEnvironment(
            workflows=[
                Workflow(
                    dag_id="ucl_discovery",
                    name="UCL Discovery Telescope",
                    class_name="oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.UclDiscoveryTelescope",
                    cloud_workspace=self.fake_cloud_workspace,
                    kwargs=dict(sheet_id="foo"),
                )
            ]
        )
        with env.create():
            self.assert_dag_load_from_config("ucl_discovery")

    def test_telescope(self):
        """Test the UCL Discovery telescope end to end.

        Runs every task of the DAG in order inside an Observatory environment, with the
        Google Sheets client patched out and HTTP traffic replayed from a VCR cassette,
        and checks the artefacts produced by each stage.
        """
        # Setup Observatory environment
        env = ObservatoryEnvironment(
            self.project_id, self.data_location, api_host="localhost", api_port=find_free_port()
        )

        # Setup Telescope
        data_partner = partner_from_str("ucl_discovery")
        data_partner.bq_dataset_id = env.add_dataset()
        telescope = UclDiscoveryTelescope(
            dag_id="ucl_discovery",
            cloud_workspace=env.cloud_workspace,
            sheet_id="foo",
            data_partner=data_partner,
            max_threads=1,
        )
        dag = telescope.make_dag()
        execution_date = pendulum.datetime(year=2023, month=6, day=1)

        # Create the Observatory environment and run tests
        with env.create(), env.create_dag_run(dag, execution_date):
            # Mock return values of download function
            interval_start = pendulum.instance(env.dag_run.data_interval_start)
            sheet_return = [
                ["ISBN13", "discovery_eprintid", "date", "title_list_title"],
                ["ISBN_1", "eprint_id1", interval_start.add(days=10).format("YYYYMMDD"), "title1"],
                ["ISBN_2", "", interval_start.add(days=10).format("YYYYMMDD"), "title2"],  # should be ignored
                ["ISBN_3", "eprint_id3", interval_start.add(years=1).format("YYYYMMDD"), "title3"],  # should be ignored
                ["", "eprint_id4", interval_start.add(days=10).format("YYYYMMDD"), "title4"],  # should be ignored
                ["ISBN_5", "eprint_id5", interval_start.subtract(months=5).format("YYYYMMDD"), "title5"],
            ]
            conn = Connection(
                conn_id="oaebu_service_account",
                uri=f"google-cloud-platform://?type=service_account&private_key_id=private_key_id"
                f"&private_key=private_key"
                f"&client_email=client_email"
                f"&client_id=client_id"
                f"&token_uri=token_uri",
            )
            env.add_connection(conn)

            ############################
            ### Main telescope tasks ###
            ############################

            # Test that all dependencies are specified: no error should be thrown
            ti = env.run_task(telescope.check_dependencies.__name__)
            # Verify the outcome explicitly, as is done for every other task below
            self.assertEqual(ti.state, State.SUCCESS)

            # download
            # record_mode="none": only replay what is in the cassette, never record new requests
            cassette = vcr.VCR(record_mode="none")
            sa_patch = patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.service_account")
            conn_patch = patch(
                "oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.BaseHook.get_connection"
            )
            build_patch = patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.discovery.build")
            with sa_patch, conn_patch, build_patch as mock_build, cassette.use_cassette(self.download_cassette):
                # Make spreadsheets().values().get().execute() return the fake sheet
                mock_service = mock_build.return_value.spreadsheets.return_value.values.return_value.get.return_value
                mock_service.execute.return_value = {"values": sheet_return}
                ti = env.run_task(telescope.download.__name__)
            self.assertEqual(ti.state, State.SUCCESS)

            # upload_downloaded
            ti = env.run_task(telescope.upload_downloaded.__name__)
            self.assertEqual(ti.state, State.SUCCESS)

            # transform
            # The same sheet mock is installed again for the transform task
            with sa_patch, conn_patch, build_patch as mock_build:
                mock_service = mock_build.return_value.spreadsheets.return_value.values.return_value.get.return_value
                mock_service.execute.return_value = {"values": sheet_return}
                ti = env.run_task(telescope.transform.__name__)
            self.assertEqual(ti.state, State.SUCCESS)

            # upload_transformed
            ti = env.run_task(telescope.upload_transformed.__name__)
            self.assertEqual(ti.state, State.SUCCESS)

            # bq_load
            ti = env.run_task(telescope.bq_load.__name__)
            self.assertEqual(ti.state, State.SUCCESS)

            ##############################################
            ### Create the release and make assertions ###
            ##############################################

            release = telescope.make_release(
                run_id=env.dag_run.run_id,
                data_interval_start=pendulum.parse(str(env.dag_run.data_interval_start)),
                data_interval_end=pendulum.parse(str(env.dag_run.data_interval_end)),
            )

            # Download
            self.assertTrue(os.path.exists(release.download_country_path))
            self.assertTrue(os.path.exists(release.download_totals_path))

            # Upload Downloaded
            download_country_blob = gcs_blob_name_from_path(release.download_country_path)
            self.assert_blob_integrity(env.download_bucket, download_country_blob, release.download_country_path)
            download_totals_blob = gcs_blob_name_from_path(release.download_totals_path)
            self.assert_blob_integrity(env.download_bucket, download_totals_blob, release.download_totals_path)

            # Transform
            self.assertTrue(os.path.exists(release.transform_path))

            # Upload Transform
            self.assert_blob_integrity(
                env.transform_bucket, gcs_blob_name_from_path(release.transform_path), release.transform_path
            )

            # Bigquery load
            table_id = bq_table_id(
                telescope.cloud_workspace.project_id,
                telescope.data_partner.bq_dataset_id,
                telescope.data_partner.bq_table_name,
            )
            self.assert_table_integrity(table_id, 2)
            self.assert_table_content(
                table_id, load_and_parse_json(self.test_table, date_fields="release_date"), "ISBN"
            )

            ###################
            ### Final tasks ###
            ###################

            # Add_dataset_release_task: no releases before the task runs, exactly one after
            dataset_releases = get_dataset_releases(dag_id=telescope.dag_id, dataset_id=telescope.api_dataset_id)
            self.assertEqual(len(dataset_releases), 0)
            ti = env.run_task(telescope.add_new_dataset_releases.__name__)
            self.assertEqual(ti.state, State.SUCCESS)
            dataset_releases = get_dataset_releases(dag_id=telescope.dag_id, dataset_id=telescope.api_dataset_id)
            self.assertEqual(len(dataset_releases), 1)

            # Test that all telescope data deleted
            ti = env.run_task(telescope.cleanup.__name__)
            self.assertEqual(ti.state, State.SUCCESS)
            self.assert_cleanup(release.workflow_folder)
class TestGetIsbnEprintMappings(TestCase):
    """Tests for the get_isbn_eprint_mappings function"""

    def __init__(self, *args, **kwargs) -> None:
        """Constructor which sets the cutoff date shared by all tests in this class."""
        super().__init__(*args, **kwargs)
        # Set the cutoff date for the tests
        self.cutoff_date = pendulum.datetime(year=2023, month=6, day=30)

    # NOTE: patch decorators are applied bottom-up, so the mock for the decorator closest to
    # the function (discovery.build) is passed first, followed by get_connection and service_account.
    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.service_account")
    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.BaseHook.get_connection")
    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.discovery.build")
    def test_get_isbn_eprint_mappings(self, mock_build, mock_get_connection, mock_sa):
        """Rows dated after the cutoff are dropped; rows on or before it are returned."""
        # Mock the Google Sheets API response
        sheet_contents = [
            ["ISBN13", "discovery_eprintid", "date", "title_list_title"],
            ["111", "eprint_1", "2023-08-01", "title1"],  # past cutoff, should be ignored
            ["222", "eprint_2", "2023-06-01", "title2"],
            ["333", "eprint_3", "2023-07-01", "title3"],  # past cutoff, should be ignored
            ["444", "eprint_4", "2023-06-30", "title4"],  # exactly on the cutoff, expected to be kept
        ]
        mock_service = mock_build.return_value.spreadsheets.return_value.values.return_value.get.return_value
        mock_service.execute.return_value = {"values": sheet_contents}

        # Call the function to test
        mappings = get_isbn_eprint_mappings("sheet_id", "service_account_conn_id", self.cutoff_date)

        # Assert that the returned mappings match the expected mappings
        expected_mappings = {
            "eprint_2": {"ISBN13": "222", "title": "title2"},
            "eprint_4": {"ISBN13": "444", "title": "title4"},
        }
        self.assertEqual(mappings, expected_mappings)

    # Exactly one patch per injected mock: stacking the same three patches twice would pass
    # six mocks to a method that accepts three and fail with a TypeError before the body runs.
    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.service_account")
    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.BaseHook.get_connection")
    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.discovery.build")
    def test_empty_sheet(self, mock_build, mock_get_connection, mock_sa):
        """A sheet response without any values raises a ValueError."""
        # Mock the Google Sheets API response with an empty sheet
        mock_build.return_value.spreadsheets.return_value.values.return_value.get.return_value.execute.return_value = {}
        with self.assertRaisesRegex(ValueError, "No content found"):
            get_isbn_eprint_mappings("sheet_id", "service_account_conn_id", self.cutoff_date)

    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.service_account")
    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.BaseHook.get_connection")
    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.discovery.build")
    def test_missing_values(self, mock_build, mock_get_connection, mock_sa):
        """Rows missing the eprint ID or the ISBN are skipped; a missing title is tolerated."""
        # Mock the Google Sheets API response with a missing value
        sheet_contents = [
            ["ISBN13", "discovery_eprintid", "date", "title_list_title"],
            ["111", "", "2023-06-01", "title1"],  # eprint ID missing
            ["", "eprint_2", "2023-06-01", "title2"],  # ISBN missing
            ["333", "eprint_3", "2023-06-01", ""],  # Title missing, should still pass
        ]
        mock_service = mock_build.return_value.spreadsheets.return_value.values.return_value.get.return_value
        mock_service.execute.return_value = {"values": sheet_contents}

        mappings = get_isbn_eprint_mappings("sheet_id", "service_account_conn_id", self.cutoff_date)
        expected_mappings = {"eprint_3": {"ISBN13": "333", "title": ""}}
        self.assertEqual(mappings, expected_mappings)
class TestDownloadDiscoveryStats(TestCase):
    """Unit tests covering the download_discovery_stats function"""

    def __init__(self, *args, **kwargs) -> None:
        """Store the reporting window and the eprint ID that every test in this class uses."""
        super().__init__(*args, **kwargs)
        self.eprint_id = "12345"
        # Reporting window requested from the stats endpoint
        self.start_date = pendulum.datetime(2022, 1, 1)
        self.end_date = pendulum.datetime(2022, 1, 31)
        # The same window rendered as YYYYMMDD strings
        self.start_formatted = self.start_date.format("YYYYMMDD")
        self.end_formatted = self.end_date.format("YYYYMMDD")

    @staticmethod
    def _stats_payload(date_from, date_to, eprint_id):
        """Build a fake JSON body carrying only the timescale and eprint ID fields."""
        return {"timescale": {"from": date_from, "to": date_to}, "set": {"value": eprint_id}}

    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.retry_get_url")
    def test_download_discovery_stats(self, mock_retry_get_url):
        """Valid inputs produce the expected request URLs and hand back both JSON payloads"""
        date_from = self.start_date.format("YYYYMMDD")
        date_to = self.end_date.format("YYYYMMDD")
        # Both requests share everything up to and including the eprint selector
        common_prefix = (
            "https://discovery.ucl.ac.uk/cgi/stats/get"
            f"?from={date_from}&to={date_to}"
            f"&irs2report=eprint&set_name=eprint&set_value={self.eprint_id}"
        )
        expected_countries_url = (
            common_prefix + "&datatype=countries&top=countries" + "&view=Table&limit=all&export=JSON"
        )
        expected_totals_url = (
            common_prefix
            + "&datatype=downloads&graph_type=column"
            + "&view=Google%3A%3AGraph&date_resolution=month&title=Download+activity+-+last+12+months&export=JSON"
        )

        # One payload per request: countries first, totals second
        http_returns = [
            self._stats_payload(self.start_formatted, self.end_formatted, self.eprint_id) for _ in range(2)
        ]
        mock_retry_get_url.return_value.json.side_effect = http_returns

        result = download_discovery_stats(self.eprint_id, self.start_date, self.end_date)

        # The URLs must have been requested in this order, each followed by a .json() call
        mock_retry_get_url.assert_has_calls(
            [
                call(expected_countries_url),
                call().json(),
                call(expected_totals_url),
                call().json(),
            ]
        )

        # The payloads come back unchanged and in request order
        self.assertEqual(result[0], http_returns[0])
        self.assertEqual(result[1], http_returns[1])

    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.retry_get_url")
    def test_download_discovery_stats_invalid_timescale(self, mock_retry_get_url):
        """A response whose timescale differs from the requested window raises a ValueError"""
        mock_retry_get_url.return_value.json.side_effect = [
            self._stats_payload("19700101", "19700130", self.eprint_id) for _ in range(2)
        ]
        with self.assertRaisesRegex(ValueError, "timescale"):
            download_discovery_stats(self.eprint_id, self.start_date, self.end_date)

    @patch("oaebu_workflows.ucl_discovery_telescope.ucl_discovery_telescope.retry_get_url")
    def test_download_discovery_stats_invalid_eprint_id(self, mock_retry_get_url):
        """A response reporting a different eprint ID than the one requested raises a ValueError"""
        mock_retry_get_url.return_value.json.side_effect = [
            self._stats_payload(self.start_formatted, self.end_formatted, "67890") for _ in range(2)
        ]
        with self.assertRaisesRegex(ValueError, "eprint ID"):
            download_discovery_stats(self.eprint_id, self.start_date, self.end_date)