Skip to content

Commit

Permalink
rebase code generation wrappers refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
SergeyRyabinin committed Jan 23, 2025
1 parent d40845c commit 99c034e
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 11 deletions.
3 changes: 2 additions & 1 deletion tools/scripts/codegen/legacy_c2j_cpp_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def generate(self, executor: ProcessPoolExecutor, max_workers: int, args: dict)
return -1

print(f"Code generation done, (re)generated {len(done)} packages.") # Including defaults and partitions
return 0

def _init_common_java_cli(self,
service_name: str,
Expand Down Expand Up @@ -172,7 +173,7 @@ def _init_common_java_cli(self,
run_command += ["--service", service_name]
run_command += ["--outputfile", output_filename]

if service_name in SMITHY_SUPPORTED_CLIENTS:
if service_name in SMITHY_SUPPORTED_CLIENTS or model_files.use_smithy:
run_command += ["--use-smithy-client"]

for key, val in kwargs.items():
Expand Down
58 changes: 53 additions & 5 deletions tools/scripts/codegen/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
A set of utils to go through c2j models and their corresponding endpoint rules
"""
import datetime
import json
import os
import re

Expand All @@ -18,6 +19,29 @@
"transcribe-streaming": "transcribestreaming",
"streams.dynamodb": "dynamodbstreams"}

SMITHY_EXCLUSION_CLIENTS = {
# multi auth
"eventbridge"
, "cloudfront-keyvaluestore"
, "cognito-identity"
, "cognito-idp"
# customization
, "machinelearning"
, "apigatewayv2"
, "apigateway"
, "eventbridge"
, "glacier"
, "lambda"
, "polly"
, "sqs"
# bearer token
# ,"codecatalyst"
# bidirectional streaming
, "lexv2-runtime"
, "qbusiness"
, "transcribestreaming"
}

# Regexp to parse C2J model filename to extract service name and date version
SERVICE_MODEL_FILENAME_PATTERN = re.compile(
"^"
Expand All @@ -29,12 +53,13 @@

class ServiceModel(object):
# A helper class to store C2j model info and metadata (endpoint rules and tests)
def __init__(self, service_id, c2j_model, endpoint_rule_set, endpoint_tests):
def __init__(self, service_id: str, c2j_model: str, endpoint_rule_set: str, endpoint_tests: str, use_smithy: bool):
self.service_id = service_id # For debugging purposes, not used atm
# only filenames, no filesystem path
self.c2j_model = c2j_model
self.endpoint_rule_set = endpoint_rule_set
self.endpoint_tests = endpoint_tests
self.use_smithy = use_smithy


class ModelUtils(object):
Expand Down Expand Up @@ -113,7 +138,8 @@ def _collect_available_models(models_dir: str, endpoint_rules_dir: str) -> dict:

# fetch endpoint-rules filename which is based on ServiceId in c2j models:
try:
service_name_to_model_filename[key] = ModelUtils._build_service_model(endpoint_rules_dir,
service_name_to_model_filename[key] = ModelUtils._build_service_model(models_dir,
endpoint_rules_dir,
model_file_date[0])

if key == "s3":
Expand All @@ -125,7 +151,8 @@ def _collect_available_models(models_dir: str, endpoint_rules_dir: str) -> dict:
service_name_to_model_filename[key] = ServiceModel(service_id=key,
c2j_model=model_file_date[0],
endpoint_rule_set=None,
endpoint_tests=None)
endpoint_tests=None,
use_smithy=False)
if missing:
# TODO: re-enable with endpoints introduction
# print(f"Missing endpoints for services: {missing}")
Expand All @@ -137,7 +164,25 @@ def _collect_available_models(models_dir: str, endpoint_rules_dir: str) -> dict:
return service_name_to_model_filename

@staticmethod
def _build_service_model(endpoint_rules_dir: str, c2j_model_filename) -> ServiceModel:
def is_smithy_enabled(service_id, models_dir, c2j_model_filename):
"""Return true if given service id and c2j model file should enable smithy client generation path
:param service_id:
:param models_dir:
:param c2j_model_filename:
:return:
"""
use_smithy = False
if service_id not in SMITHY_EXCLUSION_CLIENTS:
with open(models_dir + "/" + c2j_model_filename, 'r') as json_file:
model = json.load(json_file)
model_protocol = model.get("metadata", dict()).get("protocol", "UNKNOWN_PROTOCOL")
if model_protocol in {"json", "rest-json"}:
use_smithy = True
return use_smithy

@staticmethod
def _build_service_model(models_dir: str, endpoint_rules_dir: str, c2j_model_filename) -> ServiceModel:
"""Return a ServiceModel containing paths to the Service models: C2J model and endpoints (rules and tests).
:param models_dir (str): filepath (absolute or relative) to the dir with c2j models
Expand All @@ -153,8 +198,11 @@ def _build_service_model(endpoint_rules_dir: str, c2j_model_filename) -> Service
match = SERVICE_MODEL_FILENAME_PATTERN.match(c2j_model_filename)
service_id = match.group("service")

use_smithy = ModelUtils.is_smithy_enabled(service_id, models_dir, c2j_model_filename)

if os.path.exists(endpoint_rules_filepath) and os.path.exists(endpoint_tests_filepath):
return ServiceModel(service_id=service_id,
c2j_model=c2j_model_filename,
endpoint_rule_set=endpoint_rules_filename,
endpoint_tests=endpoint_tests_filename)
endpoint_tests=endpoint_tests_filename,
use_smithy=use_smithy)
6 changes: 4 additions & 2 deletions tools/scripts/codegen/protocol_tests_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED, ALL_COMPLETED

from codegen.legacy_c2j_cpp_gen import LegacyC2jCppGen
from codegen.model_utils import SERVICE_MODEL_FILENAME_PATTERN, ServiceModel
from codegen.model_utils import SERVICE_MODEL_FILENAME_PATTERN, ServiceModel, ModelUtils

PROTOCOL_TESTS_BASE_DIR = "tools/code-generation/protocol-tests"
PROTOCOL_TESTS_CLIENT_MODELS = PROTOCOL_TESTS_BASE_DIR + "/api-descriptions"
Expand Down Expand Up @@ -98,6 +98,7 @@ def _generate_test_clients(self, executor: ProcessPoolExecutor, max_workers: int

if len(failures):
return -1
return 0

def _collect_test_client_models(self) -> dict:
service_models = dict()
Expand All @@ -112,8 +113,9 @@ def _collect_test_client_models(self) -> dict:
if service_model_name in UNSUPPORTED_CLIENTS:
continue

use_smithy = ModelUtils.is_smithy_enabled(service_model_name, self.client_models_dir, filename)
service_models[service_model_name] = ServiceModel(service_model_name, filename,
PROTOCOL_TESTS_ENDPOINT_RULES, None)
PROTOCOL_TESTS_ENDPOINT_RULES, None, use_smithy)
return service_models

def _generate_tests(self):
Expand Down
3 changes: 3 additions & 0 deletions tools/scripts/codegen/smoke_tests_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def generate(self, clients_to_build: set):
self._copy_cpp_codegen_contents(os.path.abspath("tools/code-generation/smithy/codegen"),
"cpp-codegen-smoke-tests-plugin",
os.path.abspath("generated/smoke-tests"))
return 0
else:
return -1

def _generate_smoke_tests(self, smithy_services: List[str], smithy_c2j_data: str):
smithy_codegen_command = [
Expand Down
12 changes: 9 additions & 3 deletions tools/scripts/run_code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,24 @@ def main():

with ProcessPoolExecutor(max_workers=max_workers) as executor:
c2j_gen = LegacyC2jCppGen(args, model_utils.models_to_generate)
c2j_gen.generate(executor, max_workers, args)
if c2j_gen.generate(executor, max_workers, args) != 0:
print("ERROR: Failed to generate service client(s)!")
return -1

if args["generate_protocol_tests"]:
protocol_tests_generator = ProtocolTestsGen(args)
protocol_tests_generator.generate(executor, max_workers)
if protocol_tests_generator.generate(executor, max_workers) != 0:
print("ERROR: Failed to generate protocol test(s)!")
return -1

# generate code using smithy for all discoverable clients
# clients_to_build check is present because user can generate only defaults or partitions or protocol-tests
clients_to_build = model_utils.get_clients_to_build()
if args["generate_smoke_tests"] and clients_to_build:
smoke_tests_gen = SmokeTestsGen(args["debug"])
smoke_tests_gen.generate(clients_to_build)
if smoke_tests_gen.generate(clients_to_build) != 0:
print("ERROR: Failed to generate smoke test(s)!")
return -1

return 0

Expand Down

0 comments on commit 99c034e

Please sign in to comment.